diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index ee16c0f8df..5b20345e37 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -5,6 +5,8 @@ on: - cron: "0 6 * * 1" # Weekly Monday 6 AM UTC workflow_dispatch: pull_request: + branches: + - main paths: - "src/channels/web/**" - "tests/e2e/**" @@ -50,9 +52,11 @@ jobs: - group: core files: "tests/e2e/scenarios/test_connection.py tests/e2e/scenarios/test_chat.py tests/e2e/scenarios/test_sse_reconnect.py tests/e2e/scenarios/test_html_injection.py tests/e2e/scenarios/test_csp.py" - group: features - files: "tests/e2e/scenarios/test_skills.py tests/e2e/scenarios/test_tool_approval.py" + files: "tests/e2e/scenarios/test_skills.py tests/e2e/scenarios/test_tool_approval.py tests/e2e/scenarios/test_webhook.py" - group: extensions - files: "tests/e2e/scenarios/test_extensions.py tests/e2e/scenarios/test_extension_oauth.py tests/e2e/scenarios/test_telegram_token_validation.py tests/e2e/scenarios/test_wasm_lifecycle.py tests/e2e/scenarios/test_tool_execution.py tests/e2e/scenarios/test_pairing.py tests/e2e/scenarios/test_oauth_credential_fallback.py tests/e2e/scenarios/test_routine_oauth_credential_injection.py" + files: "tests/e2e/scenarios/test_extensions.py tests/e2e/scenarios/test_extension_oauth.py tests/e2e/scenarios/test_telegram_token_validation.py tests/e2e/scenarios/test_telegram_hot_activation.py tests/e2e/scenarios/test_wasm_lifecycle.py tests/e2e/scenarios/test_tool_execution.py tests/e2e/scenarios/test_pairing.py tests/e2e/scenarios/test_mcp_auth_flow.py tests/e2e/scenarios/test_oauth_credential_fallback.py tests/e2e/scenarios/test_routine_oauth_credential_injection.py" + - group: routines + files: "tests/e2e/scenarios/test_owner_scope.py tests/e2e/scenarios/test_routine_event_batch.py" steps: - uses: actions/checkout@v6 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c3ceb8b61c..7946c3535c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: matrix: include: - name: all-features - flags: "--features postgres,libsql,html-to-markdown" + flags: "--all-features" - name: default flags: "" - name: libsql-only diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md index 0cda8caaac..d00ff5e5df 100644 --- a/FEATURE_PARITY.md +++ b/FEATURE_PARITY.md @@ -20,9 +20,9 @@ This document tracks feature parity between IronClaw (Rust implementation) and O |---------|----------|----------|-------| | Hub-and-spoke architecture | ✅ | ✅ | Web gateway as central hub | | WebSocket control plane | ✅ | ✅ | Gateway with WebSocket + SSE | -| Single-user system | ✅ | ✅ | | +| Single-user system | ✅ | ✅ | Explicit instance owner scope for persistent routines, secrets, jobs, settings, extensions, and workspace memory | | Multi-agent routing | ✅ | ❌ | Workspace isolation per-agent | -| Session-based messaging | ✅ | ✅ | Per-sender sessions | +| Session-based messaging | ✅ | ✅ | Owner scope is separate from sender identity and conversation scope | | Loopback-first networking | ✅ | ✅ | HTTP binds to 0.0.0.0 but can be configured | ### Owner: _Unassigned_ @@ -66,9 +66,9 @@ This document tracks feature parity between IronClaw (Rust implementation) and O | CLI/TUI | ✅ | ✅ | - | Ratatui-based TUI | | HTTP webhook | ✅ | ✅ | - | axum with secret validation | | REPL (simple) | ✅ | ✅ | - | For testing | -| WASM channels | ❌ | ✅ | - | IronClaw innovation | +| WASM channels | ❌ | ✅ | - | IronClaw innovation; host resolves owner scope vs sender identity | | WhatsApp | ✅ | ❌ | P1 | Baileys (Web), same-phone mode with echo detection | -| Telegram | ✅ | ✅ | - | WASM channel(MTProto), DM pairing, caption, /start, bot_username, DM topics, setup-time owner verification | +| Telegram | ✅ | ✅ | - | WASM channel(MTProto), DM pairing, caption, /start, bot_username, DM topics, setup-time owner verification, owner-scoped persistence | | Discord | ✅ | ❌ | P2 | discord.js, thread parent binding inheritance | | Signal | ✅ | ✅ | P2 | signal-cli daemonPC, SSE listener HTTP/JSON-R, user/group allowlists, DM pairing | | Slack | ✅ | ✅ | - | WASM tool | diff --git a/channels-src/telegram/src/lib.rs b/channels-src/telegram/src/lib.rs index 936197bc04..a095ccb3a2 100644 --- a/channels-src/telegram/src/lib.rs +++ b/channels-src/telegram/src/lib.rs @@ -102,7 +102,6 @@ struct TelegramMessage { sticker: Option, /// Forum topic ID. Present when the message is sent inside a forum topic. - /// https://core.telegram.org/bots/api#message #[serde(default)] message_thread_id: Option, @@ -207,10 +206,6 @@ struct TelegramChat { /// Title for groups/channels. title: Option, - /// True when the supergroup has topics (forum mode) enabled. - #[serde(default)] - is_forum: Option, - /// Username for private chats. username: Option, } @@ -508,8 +503,7 @@ impl Guest for TelegramChannel { // Delete any existing webhook before polling. Telegram returns success // when no webhook exists, so any error here (e.g. 401) means a bad token. - delete_webhook() - .map_err(|e| format!("Bot token validation failed: {}", e))?; + delete_webhook().map_err(|e| format!("Bot token validation failed: {}", e))?; } // Configure polling only if not in webhook mode @@ -697,7 +691,12 @@ impl Guest for TelegramChannel { let metadata: TelegramMessageMetadata = serde_json::from_str(&response.metadata_json) .map_err(|e| format!("Failed to parse metadata: {}", e))?; - send_response(metadata.chat_id, &response, Some(metadata.message_id), metadata.message_thread_id) + send_response( + metadata.chat_id, + &response, + Some(metadata.message_id), + metadata.message_thread_id, + ) } fn on_broadcast(user_id: String, response: AgentResponse) -> Result<(), String> { @@ -734,8 +733,6 @@ impl Guest for TelegramChannel { "action": "typing" }); - // sendChatAction requires message_thread_id even for the General - // topic (id=1), unlike sendMessage which rejects it. if let Some(thread_id) = metadata.message_thread_id { payload["message_thread_id"] = serde_json::Value::Number(thread_id.into()); } @@ -766,9 +763,13 @@ impl Guest for TelegramChannel { } TelegramStatusAction::Notify(prompt) => { // Send user-visible status updates for actionable events. - if let Err(first_err) = - send_message(metadata.chat_id, &prompt, Some(metadata.message_id), None, metadata.message_thread_id) - { + if let Err(first_err) = send_message( + metadata.chat_id, + &prompt, + Some(metadata.message_id), + None, + metadata.message_thread_id, + ) { channel_host::log( channel_host::LogLevel::Warn, &format!( @@ -777,7 +778,13 @@ impl Guest for TelegramChannel { ), ); - if let Err(retry_err) = send_message(metadata.chat_id, &prompt, None, None, metadata.message_thread_id) { + if let Err(retry_err) = send_message( + metadata.chat_id, + &prompt, + None, + None, + metadata.message_thread_id, + ) { channel_host::log( channel_host::LogLevel::Debug, &format!( @@ -822,9 +829,8 @@ impl std::fmt::Display for SendError { /// Normalize `message_thread_id` for outbound API calls. /// -/// Telegram rejects `sendMessage` (and other send methods) when -/// `message_thread_id = 1` (the "General" topic). Return `None` in that -/// case so the field is omitted from the payload. +/// Telegram rejects `sendMessage` and file-send methods when +/// `message_thread_id = 1` (the "General" topic), so omit it in that case. fn normalize_thread_id(thread_id: Option) -> Option { thread_id.filter(|&id| id != 1) } @@ -950,19 +956,20 @@ fn download_telegram_file(file_id: &str) -> Result, String> { ); let headers = serde_json::json!({}); - let result = - channel_host::http_request("GET", &get_file_url, &headers.to_string(), None, None); + let result = channel_host::http_request("GET", &get_file_url, &headers.to_string(), None, None); let response = result.map_err(|e| format!("getFile request failed: {}", e))?; if response.status != 200 { let body_str = String::from_utf8_lossy(&response.body); - return Err(format!("getFile returned {}: {}", response.status, body_str)); + return Err(format!( + "getFile returned {}: {}", + response.status, body_str + )); } - let api_response: TelegramApiResponse = - serde_json::from_slice(&response.body) - .map_err(|e| format!("Failed to parse getFile response: {}", e))?; + let api_response: TelegramApiResponse = serde_json::from_slice(&response.body) + .map_err(|e| format!("Failed to parse getFile response: {}", e))?; if !api_response.ok { return Err(format!( @@ -992,16 +999,12 @@ fn download_telegram_file(file_id: &str) -> Result, String> { file_path ); - let result = - channel_host::http_request("GET", &download_url, &headers.to_string(), None, None); + let result = channel_host::http_request("GET", &download_url, &headers.to_string(), None, None); let response = result.map_err(|e| format!("File download failed: {}", e))?; if response.status != 200 { - return Err(format!( - "File download returned status {}", - response.status - )); + return Err(format!("File download returned status {}", response.status)); } // Post-download size guard: Telegram metadata file_size is optional, @@ -1088,7 +1091,14 @@ fn send_photo( data.len() ), ); - return send_document(chat_id, filename, mime_type, data, reply_to_message_id, message_thread_id); + return send_document( + chat_id, + filename, + mime_type, + data, + reply_to_message_id, + message_thread_id, + ); } let boundary = format!("ironclaw-{}", channel_host::now_millis()); @@ -1096,10 +1106,20 @@ fn send_photo( write_multipart_field(&mut body, &boundary, "chat_id", &chat_id.to_string()); if let Some(msg_id) = reply_to_message_id { - write_multipart_field(&mut body, &boundary, "reply_to_message_id", &msg_id.to_string()); + write_multipart_field( + &mut body, + &boundary, + "reply_to_message_id", + &msg_id.to_string(), + ); } if let Some(thread_id) = message_thread_id { - write_multipart_field(&mut body, &boundary, "message_thread_id", &thread_id.to_string()); + write_multipart_field( + &mut body, + &boundary, + "message_thread_id", + &thread_id.to_string(), + ); } write_multipart_file(&mut body, &boundary, "photo", filename, mime_type, data); body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes()); @@ -1151,10 +1171,20 @@ fn send_document( write_multipart_field(&mut body, &boundary, "chat_id", &chat_id.to_string()); if let Some(msg_id) = reply_to_message_id { - write_multipart_field(&mut body, &boundary, "reply_to_message_id", &msg_id.to_string()); + write_multipart_field( + &mut body, + &boundary, + "reply_to_message_id", + &msg_id.to_string(), + ); } if let Some(thread_id) = message_thread_id { - write_multipart_field(&mut body, &boundary, "message_thread_id", &thread_id.to_string()); + write_multipart_field( + &mut body, + &boundary, + "message_thread_id", + &thread_id.to_string(), + ); } write_multipart_file(&mut body, &boundary, "document", filename, mime_type, data); body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes()); @@ -1191,12 +1221,7 @@ fn send_document( } /// Image MIME types that Telegram's sendPhoto API supports. -const PHOTO_MIME_TYPES: &[&str] = &[ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", -]; +const PHOTO_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"]; /// Send a full agent response (attachments + text) to a chat. /// @@ -1218,13 +1243,23 @@ fn send_response( } // Try Markdown, fall back to plain text on parse errors - match send_message(chat_id, &response.content, reply_to_message_id, Some("Markdown"), message_thread_id) { + match send_message( + chat_id, + &response.content, + reply_to_message_id, + Some("Markdown"), + message_thread_id, + ) { Ok(_) => Ok(()), - Err(SendError::ParseEntities(_)) => { - send_message(chat_id, &response.content, reply_to_message_id, None, message_thread_id) - .map(|_| ()) - .map_err(|e| format!("Plain-text retry also failed: {}", e)) - } + Err(SendError::ParseEntities(_)) => send_message( + chat_id, + &response.content, + reply_to_message_id, + None, + message_thread_id, + ) + .map(|_| ()) + .map_err(|e| format!("Plain-text retry also failed: {}", e)), Err(e) => Err(e.to_string()), } } @@ -1392,7 +1427,10 @@ fn register_webhook(tunnel_url: &str, webhook_secret: Option<&str>) -> Result<() let context = if retried { " (after retry)" } else { "" }; channel_host::log( channel_host::LogLevel::Info, - &format!("Webhook registered successfully{}: {}", context, webhook_url), + &format!( + "Webhook registered successfully{}: {}", + context, webhook_url + ), ); Ok(()) @@ -1412,7 +1450,7 @@ fn send_pairing_reply(chat_id: i64, code: &str) -> Result<(), String> { ), None, Some("Markdown"), - None, // Pairing happens in DMs, not forum topics + None, ) .map(|_| ()) .map_err(|e| e.to_string()) @@ -1494,7 +1532,9 @@ fn extract_attachments(message: &TelegramMessage) -> Vec { if let Some(ref doc) = message.document { attachments.push(make_inbound_attachment( doc.file_id.clone(), - doc.mime_type.clone().unwrap_or_else(|| "application/octet-stream".to_string()), + doc.mime_type + .clone() + .unwrap_or_else(|| "application/octet-stream".to_string()), doc.file_name.clone(), doc.file_size.map(|s| s as u64), Some(get_file_url(&doc.file_id)), @@ -1507,7 +1547,10 @@ fn extract_attachments(message: &TelegramMessage) -> Vec { if let Some(ref audio) = message.audio { attachments.push(make_inbound_attachment( audio.file_id.clone(), - audio.mime_type.clone().unwrap_or_else(|| "audio/mpeg".to_string()), + audio + .mime_type + .clone() + .unwrap_or_else(|| "audio/mpeg".to_string()), audio.file_name.clone(), audio.file_size.map(|s| s as u64), Some(get_file_url(&audio.file_id)), @@ -1520,7 +1563,10 @@ fn extract_attachments(message: &TelegramMessage) -> Vec { if let Some(ref video) = message.video { attachments.push(make_inbound_attachment( video.file_id.clone(), - video.mime_type.clone().unwrap_or_else(|| "video/mp4".to_string()), + video + .mime_type + .clone() + .unwrap_or_else(|| "video/mp4".to_string()), video.file_name.clone(), video.file_size.map(|s| s as u64), Some(get_file_url(&video.file_id)), @@ -1745,25 +1791,14 @@ fn handle_message(message: TelegramMessage) { let is_private = message.chat.chat_type == "private"; - // Owner validation: when owner_id is set, only that user can message - let owner_id_str = channel_host::workspace_read(OWNER_ID_PATH).filter(|s| !s.is_empty()); + let owner_id = channel_host::workspace_read(OWNER_ID_PATH) + .filter(|s| !s.is_empty()) + .and_then(|s| s.parse::().ok()); + let is_owner = owner_id == Some(from.id); - if let Some(ref id_str) = owner_id_str { - if let Ok(owner_id) = id_str.parse::() { - if from.id != owner_id { - channel_host::log( - channel_host::LogLevel::Debug, - &format!( - "Dropping message from non-owner user {} (owner: {})", - from.id, owner_id - ), - ); - return; - } - } - } else { - // No owner_id: apply authorization based on dm_policy and allow_from - // This applies to both private and group chats when owner_id is null + if !is_owner { + // Non-owner senders remain guests. Apply authorization based on + // dm_policy / allow_from before letting them chat in their own scope. let dm_policy = channel_host::workspace_read(DM_POLICY_PATH).unwrap_or_else(|| "pairing".to_string()); @@ -1830,8 +1865,6 @@ fn handle_message(message: TelegramMessage) { } } - let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); - // For group chats, only respond if bot was mentioned or respond_to_all is enabled if !is_private { let respond_to_all = channel_host::workspace_read(RESPOND_TO_ALL_GROUP_PATH) @@ -1841,6 +1874,7 @@ fn handle_message(message: TelegramMessage) { if !respond_to_all { let has_command = content.starts_with('/'); + let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); let has_bot_mention = if bot_username.is_empty() { content.contains('@') } else { @@ -1876,18 +1910,7 @@ fn handle_message(message: TelegramMessage) { let metadata_json = serde_json::to_string(&metadata).unwrap_or_else(|_| "{}".to_string()); - // Compute thread_id for forum topics: "chat_id:topic_id" to prevent - // collisions across different groups (topic IDs are only unique per chat). - // Only use message_thread_id when the chat is a forum — non-forum groups - // also carry message_thread_id for reply threads, which are not topics. - let thread_id = if message.chat.is_forum == Some(true) { - message.message_thread_id.map(|topic_id| { - format!("{}:{}", message.chat.id, topic_id) - }) - } else { - None - }; - + let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); let content_to_emit = match content_to_emit_for_agent( &content, if bot_username.is_empty() { @@ -1907,7 +1930,7 @@ fn handle_message(message: TelegramMessage) { user_id: from.id.to_string(), user_name: Some(user_name), content: content_to_emit, - thread_id, + thread_id: Some(message.chat.id.to_string()), metadata_json, attachments, }); @@ -2507,7 +2530,11 @@ mod tests { assert_eq!(attachments[0].id, "large_id"); // Largest photo assert_eq!(attachments[0].mime_type, "image/jpeg"); assert_eq!(attachments[0].size_bytes, Some(54321)); - assert!(attachments[0].source_url.as_ref().unwrap().contains("large_id")); + assert!(attachments[0] + .source_url + .as_ref() + .unwrap() + .contains("large_id")); } #[test] @@ -2559,9 +2586,7 @@ mod tests { attachments[0].filename.as_deref(), Some("voice_voice_xyz.ogg") ); - assert!(attachments[0] - .extras_json - .contains("\"duration_secs\":5")); + assert!(attachments[0].extras_json.contains("\"duration_secs\":5")); } #[test] @@ -2707,18 +2732,33 @@ mod tests { }; // PDFs and Office docs should be downloaded - assert!(is_downloadable_document(&make("application/pdf", Some("report.pdf")))); + assert!(is_downloadable_document(&make( + "application/pdf", + Some("report.pdf") + ))); assert!(is_downloadable_document(&make( "application/vnd.openxmlformats-officedocument.wordprocessingml.document", Some("doc.docx"), ))); - assert!(is_downloadable_document(&make("text/plain", Some("notes.txt")))); + assert!(is_downloadable_document(&make( + "text/plain", + Some("notes.txt") + ))); // Voice, image, audio, video should NOT be downloaded - assert!(!is_downloadable_document(&make("audio/ogg", Some("voice_123.ogg")))); + assert!(!is_downloadable_document(&make( + "audio/ogg", + Some("voice_123.ogg") + ))); assert!(!is_downloadable_document(&make("image/jpeg", None))); - assert!(!is_downloadable_document(&make("audio/mpeg", Some("song.mp3")))); - assert!(!is_downloadable_document(&make("video/mp4", Some("clip.mp4")))); + assert!(!is_downloadable_document(&make( + "audio/mpeg", + Some("song.mp3") + ))); + assert!(!is_downloadable_document(&make( + "video/mp4", + Some("clip.mp4") + ))); } #[test] @@ -2726,100 +2766,4 @@ mod tests { // Verify the constant is 20 MB, matching the Slack channel limit assert_eq!(MAX_DOWNLOAD_SIZE_BYTES, 20 * 1024 * 1024); } - - // === Forum Topics (thread_id) tests === - - #[test] - fn test_parse_forum_message_with_thread_id() { - let json = r#"{ - "message_id": 100, - "message_thread_id": 42, - "is_topic_message": true, - "from": {"id": 1, "is_bot": false, "first_name": "A"}, - "chat": {"id": -1001234567890, "type": "supergroup", "is_forum": true}, - "text": "Hello from a topic" - }"#; - let msg: TelegramMessage = serde_json::from_str(json).unwrap(); - assert_eq!(msg.message_thread_id, Some(42)); - assert_eq!(msg.is_topic_message, Some(true)); - assert_eq!(msg.chat.is_forum, Some(true)); - } - - #[test] - fn test_parse_non_forum_message_backward_compat() { - let json = r#"{ - "message_id": 1, - "from": {"id": 1, "is_bot": false, "first_name": "A"}, - "chat": {"id": 1, "type": "private"}, - "text": "Hello" - }"#; - let msg: TelegramMessage = serde_json::from_str(json).unwrap(); - assert_eq!(msg.message_thread_id, None); - assert_eq!(msg.is_topic_message, None); - assert_eq!(msg.chat.is_forum, None); - } - - #[test] - fn test_metadata_with_message_thread_id() { - let metadata = TelegramMessageMetadata { - chat_id: -1001234567890, - message_id: 100, - user_id: 42, - is_private: false, - message_thread_id: Some(7), - }; - let json = serde_json::to_string(&metadata).unwrap(); - let parsed: TelegramMessageMetadata = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.message_thread_id, Some(7)); - } - - #[test] - fn test_metadata_backward_compat_no_thread_id() { - // Old metadata JSON without message_thread_id should deserialize with None - let json = r#"{"chat_id":123,"message_id":1,"user_id":42,"is_private":true}"#; - let metadata: TelegramMessageMetadata = serde_json::from_str(json).unwrap(); - assert_eq!(metadata.message_thread_id, None); - } - - #[test] - fn test_metadata_thread_id_not_serialized_when_none() { - let metadata = TelegramMessageMetadata { - chat_id: 123, - message_id: 1, - user_id: 42, - is_private: true, - message_thread_id: None, - }; - let json = serde_json::to_string(&metadata).unwrap(); - assert!(!json.contains("message_thread_id")); - } - - #[test] - fn test_thread_id_composition() { - // Verify "chat_id:topic_id" format for forum topics - let chat_id: i64 = -1001234567890; - let topic_id: i64 = 42; - let thread_id = format!("{}:{}", chat_id, topic_id); - assert_eq!(thread_id, "-1001234567890:42"); - } - - #[test] - fn test_normalize_thread_id_general_topic() { - // General topic (id=1) must be omitted — Telegram rejects sendMessage - // with message_thread_id=1. - assert_eq!(normalize_thread_id(Some(1)), None); - } - - #[test] - fn test_normalize_thread_id_regular_topic() { - // Non-General topics pass through unchanged - assert_eq!(normalize_thread_id(Some(42)), Some(42)); - assert_eq!(normalize_thread_id(Some(123)), Some(123)); - } - - #[test] - fn test_normalize_thread_id_none() { - // None stays None - assert_eq!(normalize_thread_id(None), None); - } } diff --git a/migrations/V13__owner_scope_notify_targets.sql b/migrations/V13__owner_scope_notify_targets.sql new file mode 100644 index 0000000000..4c7064fab6 --- /dev/null +++ b/migrations/V13__owner_scope_notify_targets.sql @@ -0,0 +1,11 @@ +-- Remove the legacy 'default' sentinel from routine notifications. +-- A NULL notify_user now means "resolve the configured owner's last-seen +-- channel target at send time." + +ALTER TABLE routines + ALTER COLUMN notify_user DROP NOT NULL, + ALTER COLUMN notify_user DROP DEFAULT; + +UPDATE routines +SET notify_user = NULL +WHERE notify_user = 'default'; diff --git a/migrations/V6__routines.sql b/migrations/V6__routines.sql index 36f63cb2f5..9697251cc9 100644 --- a/migrations/V6__routines.sql +++ b/migrations/V6__routines.sql @@ -26,7 +26,7 @@ CREATE TABLE routines ( -- Notification preferences notify_channel TEXT, -- NULL = use default - notify_user TEXT NOT NULL DEFAULT 'default', + notify_user TEXT, notify_on_success BOOLEAN NOT NULL DEFAULT false, notify_on_failure BOOLEAN NOT NULL DEFAULT true, notify_on_attention BOOLEAN NOT NULL DEFAULT true, diff --git a/registry/channels/discord.json b/registry/channels/discord.json index 6f5cd4e7e4..50ef85ee0a 100644 --- a/registry/channels/discord.json +++ b/registry/channels/discord.json @@ -2,7 +2,7 @@ "name": "discord", "display_name": "Discord Channel", "kind": "channel", - "version": "0.2.0", + "version": "0.2.1", "wit_version": "0.3.0", "description": "Talk to your agent in Discord", "keywords": [ diff --git a/registry/tools/github.json b/registry/tools/github.json index e84f756dcf..e775ac8216 100644 --- a/registry/tools/github.json +++ b/registry/tools/github.json @@ -2,7 +2,7 @@ "name": "github", "display_name": "GitHub", "kind": "tool", - "version": "0.2.0", + "version": "0.2.1", "wit_version": "0.3.0", "description": "GitHub integration for issues, PRs, repos, and code search", "keywords": [ diff --git a/registry/tools/web-search.json b/registry/tools/web-search.json index 4da5744b01..1722c39187 100644 --- a/registry/tools/web-search.json +++ b/registry/tools/web-search.json @@ -2,7 +2,7 @@ "name": "web-search", "display_name": "Web Search", "kind": "tool", - "version": "0.2.0", + "version": "0.2.1", "wit_version": "0.3.0", "description": "Search the web using Brave Search API", "keywords": [ diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 4b7ed5381f..aaaad879d1 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -22,7 +22,7 @@ use crate::channels::{ChannelManager, IncomingMessage, OutgoingResponse}; use crate::config::{AgentConfig, HeartbeatConfig, RoutineConfig, SkillsConfig}; use crate::context::ContextManager; use crate::db::Database; -use crate::error::Error; +use crate::error::{ChannelError, Error}; use crate::extensions::ExtensionManager; use crate::hooks::HookRegistry; use crate::llm::LlmProvider; @@ -54,10 +54,26 @@ pub(crate) fn truncate_for_preview(output: &str, max_chars: usize) -> String { } } +fn resolve_routine_notification_user(metadata: &serde_json::Value) -> Option { + metadata + .get("notify_user") + .and_then(|value| value.as_str()) + .or_else(|| metadata.get("owner_id").and_then(|value| value.as_str())) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) +} + +fn should_fallback_routine_notification(error: &ChannelError) -> bool { + !matches!(error, ChannelError::MissingRoutingTarget { .. }) +} + /// Core dependencies for the agent. /// /// Bundles the shared components to reduce argument count. pub struct AgentDeps { + /// Resolved durable owner scope for the instance. + pub owner_id: String, pub store: Option>, pub llm: Arc, /// Cheap/fast LLM for lightweight tasks (heartbeat, routing, evaluation). @@ -102,6 +118,18 @@ pub struct Agent { } impl Agent { + pub(super) fn owner_id(&self) -> &str { + if let Some(workspace) = self.deps.workspace.as_ref() { + debug_assert_eq!( + workspace.user_id(), + self.deps.owner_id, + "workspace.user_id() must stay aligned with deps.owner_id" + ); + } + + &self.deps.owner_id + } + /// Create a new agent. /// /// Optionally accepts pre-created `ContextManager` and `SessionManager` for sharing @@ -264,6 +292,7 @@ impl Agent { )); let repair_interval = self.config.repair_check_interval; let repair_channels = self.channels.clone(); + let repair_owner_id = self.owner_id().to_string(); let repair_handle = tokio::spawn(async move { loop { tokio::time::sleep(repair_interval).await; @@ -311,7 +340,9 @@ impl Agent { if let Some(msg) = notification { let response = OutgoingResponse::text(format!("Self-Repair: {}", msg)); - let _ = repair_channels.broadcast_all("default", response).await; + let _ = repair_channels + .broadcast_all(&repair_owner_id, response) + .await; } } @@ -325,7 +356,9 @@ impl Agent { "Self-Repair: Tool '{}' repaired: {}", tool.name, message )); - let _ = repair_channels.broadcast_all("default", response).await; + let _ = repair_channels + .broadcast_all(&repair_owner_id, response) + .await; } Ok(result) => { tracing::info!("Tool repair result: {:?}", result); @@ -362,9 +395,11 @@ impl Agent { .timezone .clone() .or_else(|| Some(self.config.default_timezone.clone())); - if let (Some(user), Some(channel)) = - (&hb_config.notify_user, &hb_config.notify_channel) - { + if let Some(channel) = &hb_config.notify_channel { + let user = hb_config + .notify_user + .clone() + .unwrap_or_else(|| self.owner_id().to_string()); config = config.with_notify(user, channel); } @@ -374,17 +409,18 @@ impl Agent { // Spawn notification forwarder that routes through channel manager let notify_channel = hb_config.notify_channel.clone(); - let notify_user = hb_config.notify_user.clone(); + let notify_user = hb_config + .notify_user + .clone() + .unwrap_or_else(|| self.owner_id().to_string()); let channels = self.channels.clone(); tokio::spawn(async move { while let Some(response) = notify_rx.recv().await { - let user = notify_user.as_deref().unwrap_or("default"); - // Try the configured channel first, fall back to // broadcasting on all channels. let targeted_ok = if let Some(ref channel) = notify_channel { channels - .broadcast(channel, user, response.clone()) + .broadcast(channel, ¬ify_user, response.clone()) .await .is_ok() } else { @@ -392,7 +428,7 @@ impl Agent { }; if !targeted_ok { - let results = channels.broadcast_all(user, response).await; + let results = channels.broadcast_all(¬ify_user, response).await; for (ch, result) in results { if let Err(e) = result { tracing::warn!( @@ -462,25 +498,41 @@ impl Agent { let channels = self.channels.clone(); tokio::spawn(async move { while let Some(response) = notify_rx.recv().await { - let user = response - .metadata - .get("notify_user") - .and_then(|v| v.as_str()) - .unwrap_or("default") - .to_string(); let notify_channel = response .metadata .get("notify_channel") .and_then(|v| v.as_str()) .map(|s| s.to_string()); + let Some(user) = resolve_routine_notification_user(&response.metadata) + else { + tracing::warn!( + notify_channel = ?notify_channel, + "Skipping routine notification with no explicit target or owner scope" + ); + continue; + }; // Try the configured channel first, fall back to // broadcasting on all channels. let targeted_ok = if let Some(ref channel) = notify_channel { - channels - .broadcast(channel, &user, response.clone()) - .await - .is_ok() + match channels.broadcast(channel, &user, response.clone()).await { + Ok(()) => true, + Err(e) => { + let should_fallback = + should_fallback_routine_notification(&e); + tracing::warn!( + channel = %channel, + user = %user, + error = %e, + should_fallback, + "Failed to send routine notification to configured channel" + ); + if !should_fallback { + continue; + } + false + } + } } else { false }; @@ -768,10 +820,7 @@ impl Agent { // For Signal, use signal_target from metadata (group:ID or phone number), // otherwise fall back to user_id let target = message - .metadata - .get("signal_target") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) + .routing_target() .unwrap_or_else(|| message.user_id.clone()); self.tools() .set_message_tool_context(Some(message.channel.clone()), Some(target)) @@ -811,7 +860,7 @@ impl Agent { } // Hydrate thread from DB if it's a historical thread not in memory - if let Some(ref external_thread_id) = message.thread_id { + if let Some(external_thread_id) = message.conversation_scope() { tracing::trace!( message_id = %message.id, thread_id = %external_thread_id, @@ -832,7 +881,7 @@ impl Agent { .resolve_thread( &message.user_id, &message.channel, - message.thread_id.as_deref(), + message.conversation_scope(), ) .await; tracing::debug!( @@ -985,7 +1034,11 @@ impl Agent { #[cfg(test)] mod tests { - use super::truncate_for_preview; + use super::{ + resolve_routine_notification_user, should_fallback_routine_notification, + truncate_for_preview, + }; + use crate::error::ChannelError; #[test] fn test_truncate_short_input() { @@ -1048,4 +1101,55 @@ mod tests { // 'h','e','l','l','o',' ','世','界' = 8 chars assert_eq!(result, "hello 世界..."); } + + #[test] + fn resolve_routine_notification_user_prefers_explicit_target() { + let metadata = serde_json::json!({ + "notify_user": "12345", + "owner_id": "owner-scope", + }); + + let resolved = resolve_routine_notification_user(&metadata); + assert_eq!(resolved.as_deref(), Some("12345")); // safety: test-only assertion + } + + #[test] + fn resolve_routine_notification_user_falls_back_to_owner_scope() { + let metadata = serde_json::json!({ + "notify_user": null, + "owner_id": "owner-scope", + }); + + let resolved = resolve_routine_notification_user(&metadata); + assert_eq!(resolved.as_deref(), Some("owner-scope")); // safety: test-only assertion + } + + #[test] + fn resolve_routine_notification_user_rejects_missing_values() { + let metadata = serde_json::json!({ + "notify_user": " ", + }); + + assert_eq!(resolve_routine_notification_user(&metadata), None); // safety: test-only assertion + } + + #[test] + fn targeted_routine_notifications_do_not_fallback_without_owner_route() { + let error = ChannelError::MissingRoutingTarget { + name: "telegram".to_string(), + reason: "No stored owner routing target for channel 'telegram'.".to_string(), + }; + + assert!(!should_fallback_routine_notification(&error)); // safety: test-only assertion + } + + #[test] + fn targeted_routine_notifications_may_fallback_for_other_errors() { + let error = ChannelError::SendFailed { + name: "telegram".to_string(), + reason: "timeout talking to channel".to_string(), + }; + + assert!(should_fallback_routine_notification(&error)); // safety: test-only assertion + } } diff --git a/src/agent/commands.rs b/src/agent/commands.rs index 90266d0bab..75c99359b5 100644 --- a/src/agent/commands.rs +++ b/src/agent/commands.rs @@ -836,7 +836,10 @@ impl Agent { // 1. Persist to DB if available. if let Some(store) = self.store() { let value = serde_json::Value::String(model.to_string()); - if let Err(e) = store.set_setting("default", "selected_model", &value).await { + if let Err(e) = store + .set_setting(self.owner_id(), "selected_model", &value) + .await + { tracing::warn!("Failed to persist model to DB: {}", e); } } diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs index 9e6747f2b3..9be0d654d1 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher.rs @@ -140,13 +140,15 @@ impl Agent { // Create a JobContext for tool execution (chat doesn't have a real job) let mut job_ctx = - JobContext::with_user(&message.user_id, "chat", "Interactive chat session"); + JobContext::with_user(&message.user_id, "chat", "Interactive chat session") + .with_requester_id(&message.sender_id); job_ctx.http_interceptor = self.deps.http_interceptor.clone(); job_ctx.user_timezone = user_tz.name().to_string(); job_ctx.metadata = serde_json::json!({ "notify_channel": message.channel, "notify_user": message.user_id, "notify_thread_id": message.thread_id, + "notify_metadata": message.metadata, }); // Build system prompts once for this turn. Two variants: with tools @@ -1175,6 +1177,7 @@ mod tests { /// Build a minimal `Agent` for unit testing (no DB, no workspace, no extensions). fn make_test_agent() -> Agent { let deps = AgentDeps { + owner_id: "default".to_string(), store: None, llm: Arc::new(StaticLlmProvider), cheap_llm: None, @@ -2014,6 +2017,7 @@ mod tests { /// `max_tool_iterations` override. fn make_test_agent_with_llm(llm: Arc, max_tool_iterations: usize) -> Agent { let deps = AgentDeps { + owner_id: "default".to_string(), store: None, llm, cheap_llm: None, @@ -2127,6 +2131,7 @@ mod tests { let max_iter = 3; let agent = { let deps = AgentDeps { + owner_id: "default".to_string(), store: None, llm, cheap_llm: None, diff --git a/src/agent/heartbeat.rs b/src/agent/heartbeat.rs index 77bdeadb0f..ec4cd5e9ec 100644 --- a/src/agent/heartbeat.rs +++ b/src/agent/heartbeat.rs @@ -402,7 +402,11 @@ impl HeartbeatRunner { return; }; - let user_id = self.config.notify_user_id.as_deref().unwrap_or("default"); + let user_id = self + .config + .notify_user_id + .as_deref() + .unwrap_or_else(|| self.workspace.user_id()); // Persist to heartbeat conversation and get thread_id let thread_id = if let Some(ref store) = self.store { @@ -431,6 +435,7 @@ impl HeartbeatRunner { attachments: Vec::new(), metadata: serde_json::json!({ "source": "heartbeat", + "owner_id": self.workspace.user_id(), }), }; diff --git a/src/agent/routine.rs b/src/agent/routine.rs index 0389ac1e33..f3850fa0b1 100644 --- a/src/agent/routine.rs +++ b/src/agent/routine.rs @@ -422,8 +422,8 @@ impl Default for RoutineGuardrails { pub struct NotifyConfig { /// Channel to notify on (None = default/broadcast all). pub channel: Option, - /// User to notify. - pub user: String, + /// Explicit target to notify. None means "resolve the owner's last-seen target". + pub user: Option, /// Notify when routine produces actionable output. pub on_attention: bool, /// Notify when routine errors. @@ -436,7 +436,7 @@ impl Default for NotifyConfig { fn default() -> Self { Self { channel: None, - user: "default".to_string(), + user: None, on_attention: true, on_failure: true, on_success: false, diff --git a/src/agent/routine_engine.rs b/src/agent/routine_engine.rs index c37ba7ce16..519f16c22a 100644 --- a/src/agent/routine_engine.rs +++ b/src/agent/routine_engine.rs @@ -172,6 +172,11 @@ impl RoutineEngine { EventMatcher::Message { routine, regex } => (routine, regex), EventMatcher::System { .. } => continue, }; + + if routine.user_id != message.user_id { + continue; + } + // Channel filter if let Trigger::Event { channel: Some(ch), .. @@ -650,6 +655,7 @@ async fn execute_routine(ctx: EngineContext, routine: Routine, run: RoutineRun) send_notification( &ctx.notify_tx, &routine.notify, + &routine.user_id, &routine.name, status, summary.as_deref(), @@ -694,7 +700,8 @@ async fn execute_full_job( reason: "scheduler not available".to_string(), })?; - let mut metadata = serde_json::json!({ "max_iterations": max_iterations }); + let mut metadata = + serde_json::json!({ "max_iterations": max_iterations, "owner_id": routine.user_id }); // Carry the routine's notify config in job metadata so the message tool // can resolve channel/target per-job without global state mutation. if let Some(channel) = &routine.notify.channel { @@ -1207,6 +1214,7 @@ async fn execute_routine_tool( async fn send_notification( tx: &mpsc::Sender, notify: &NotifyConfig, + owner_id: &str, routine_name: &str, status: RunStatus, summary: Option<&str>, @@ -1243,6 +1251,7 @@ async fn send_notification( "source": "routine", "routine_name": routine_name, "status": status.to_string(), + "owner_id": owner_id, "notify_user": notify.user, "notify_channel": notify.channel, }), diff --git a/src/agent/thread_ops.rs b/src/agent/thread_ops.rs index 7aa499aec0..e5f2005d25 100644 --- a/src/agent/thread_ops.rs +++ b/src/agent/thread_ops.rs @@ -924,7 +924,8 @@ impl Agent { // Execute the approved tool and continue the loop let mut job_ctx = - JobContext::with_user(&message.user_id, "chat", "Interactive chat session"); + JobContext::with_user(&message.user_id, "chat", "Interactive chat session") + .with_requester_id(&message.sender_id); job_ctx.http_interceptor = self.deps.http_interceptor.clone(); // Prefer a valid timezone from the approval message, fall back to the // resolved timezone stored when the approval was originally requested. diff --git a/src/app.rs b/src/app.rs index 00804de147..0ffe782064 100644 --- a/src/app.rs +++ b/src/app.rs @@ -140,12 +140,14 @@ impl AppBuilder { self.handles = Some(handles); // Post-init: migrate disk config, reload config from DB, attach session, cleanup - if let Err(e) = crate::bootstrap::migrate_disk_to_db(db.as_ref(), "default").await { + if let Err(e) = + crate::bootstrap::migrate_disk_to_db(db.as_ref(), &self.config.owner_id).await + { tracing::warn!("Disk-to-DB settings migration failed: {}", e); } let toml_path = self.toml_path.as_deref(); - match Config::from_db_with_toml(db.as_ref(), "default", toml_path).await { + match Config::from_db_with_toml(db.as_ref(), &self.config.owner_id, toml_path).await { Ok(db_config) => { self.config = db_config; tracing::debug!("Configuration reloaded from database"); @@ -158,7 +160,9 @@ impl AppBuilder { } } - self.session.attach_store(db.clone(), "default").await; + self.session + .attach_store(db.clone(), &self.config.owner_id) + .await; // Fire-and-forget housekeeping — no need to block startup. let db_cleanup = db.clone(); @@ -193,9 +197,10 @@ impl AppBuilder { let store: Option<&(dyn crate::db::SettingsStore + Sync)> = self.db.as_ref().map(|db| db.as_ref() as _); let toml_path = self.toml_path.as_deref(); + let owner_id = self.config.owner_id.clone(); if let Err(e) = self .config - .re_resolve_llm(store, "default", toml_path) + .re_resolve_llm(store, &owner_id, toml_path) .await { tracing::warn!( @@ -224,15 +229,17 @@ impl AppBuilder { if let Some(ref secrets) = store { // Inject LLM API keys from encrypted storage - crate::config::inject_llm_keys_from_secrets(secrets.as_ref(), "default").await; + crate::config::inject_llm_keys_from_secrets(secrets.as_ref(), &self.config.owner_id) + .await; // Re-resolve only the LLM config with newly available keys. let store: Option<&(dyn crate::db::SettingsStore + Sync)> = self.db.as_ref().map(|db| db.as_ref() as _); let toml_path = self.toml_path.as_deref(); + let owner_id = self.config.owner_id.clone(); if let Err(e) = self .config - .re_resolve_llm(store, "default", toml_path) + .re_resolve_llm(store, &owner_id, toml_path) .await { tracing::warn!("Failed to re-resolve LLM config after secret injection: {e}"); @@ -304,7 +311,7 @@ impl AppBuilder { // Register memory tools if database is available let workspace = if let Some(ref db) = self.db { - let mut ws = Workspace::new_with_db("default", db.clone()) + let mut ws = Workspace::new_with_db(&self.config.owner_id, db.clone()) .with_search_config(&self.config.search); if let Some(ref emb) = embeddings { ws = ws.with_embeddings(emb.clone()); @@ -469,9 +476,10 @@ impl AppBuilder { let tools = Arc::clone(tools); let mcp_sm = Arc::clone(&mcp_session_manager); let pm = Arc::clone(&mcp_process_manager); + let owner_id = self.config.owner_id.clone(); async move { let servers_result = if let Some(ref d) = db { - load_mcp_servers_from_db(d.as_ref(), "default").await + load_mcp_servers_from_db(d.as_ref(), &owner_id).await } else { crate::tools::mcp::config::load_mcp_servers().await }; @@ -491,6 +499,7 @@ impl AppBuilder { let secrets = secrets_store.clone(); let tools = Arc::clone(&tools); let pm = Arc::clone(&pm); + let owner_id = owner_id.clone(); join_set.spawn(async move { let server_name = server.name.clone(); @@ -500,7 +509,7 @@ impl AppBuilder { &mcp_sm, &pm, secrets, - "default", + &owner_id, ) .await { @@ -642,7 +651,7 @@ impl AppBuilder { self.config.wasm.tools_dir.clone(), self.config.channels.wasm_channels_dir.clone(), self.config.tunnel.public_url.clone(), - "default".to_string(), + self.config.owner_id.clone(), self.db.clone(), catalog_entries.clone(), )); diff --git a/src/channels/channel.rs b/src/channels/channel.rs index ed8c28ff2e..43e35688cc 100644 --- a/src/channels/channel.rs +++ b/src/channels/channel.rs @@ -67,14 +67,24 @@ pub struct IncomingMessage { pub id: Uuid, /// Channel this message came from. pub channel: String, - /// User identifier within the channel. + /// Storage/persistence scope for this interaction. + /// + /// For owner-capable channels this is the stable instance owner ID when the + /// configured owner is speaking; otherwise it can be a guest/sender-scoped + /// identifier to preserve isolation. pub user_id: String, + /// Stable instance owner scope for this IronClaw deployment. + pub owner_id: String, + /// Channel-specific sender/actor identifier. + pub sender_id: String, /// Optional display name. pub user_name: Option, /// Message content. pub content: String, /// Thread/conversation ID for threaded conversations. pub thread_id: Option, + /// Stable channel/chat/thread scope for this conversation. + pub conversation_scope_id: Option, /// When the message was received. pub received_at: DateTime, /// Channel-specific metadata. @@ -84,9 +94,8 @@ pub struct IncomingMessage { /// File or media attachments on this message. pub attachments: Vec, /// Internal-only flag: message was generated inside the process (e.g. job - /// monitor) and must bypass the normal user-input pipeline. This field is - /// **not** settable via `with_metadata()` — only trusted code paths inside - /// the binary can set it, preventing external channels from spoofing it. + /// monitor) and must bypass the normal user-input pipeline. This field is + /// not settable via metadata, so external channels cannot spoof it. pub(crate) is_internal: bool, } @@ -97,13 +106,17 @@ impl IncomingMessage { user_id: impl Into, content: impl Into, ) -> Self { + let user_id = user_id.into(); Self { id: Uuid::new_v4(), channel: channel.into(), - user_id: user_id.into(), + owner_id: user_id.clone(), + sender_id: user_id.clone(), + user_id, user_name: None, content: content.into(), thread_id: None, + conversation_scope_id: None, received_at: Utc::now(), metadata: serde_json::Value::Null, timezone: None, @@ -114,7 +127,27 @@ impl IncomingMessage { /// Set the thread ID. pub fn with_thread(mut self, thread_id: impl Into) -> Self { - self.thread_id = Some(thread_id.into()); + let thread_id = thread_id.into(); + self.conversation_scope_id = Some(thread_id.clone()); + self.thread_id = Some(thread_id); + self + } + + /// Set the stable owner scope for this message. + pub fn with_owner_id(mut self, owner_id: impl Into) -> Self { + self.owner_id = owner_id.into(); + self + } + + /// Set the channel-specific sender/actor identifier. + pub fn with_sender_id(mut self, sender_id: impl Into) -> Self { + self.sender_id = sender_id.into(); + self + } + + /// Set the conversation scope for this message. + pub fn with_conversation_scope(mut self, scope_id: impl Into) -> Self { + self.conversation_scope_id = Some(scope_id.into()); self } @@ -147,6 +180,49 @@ impl IncomingMessage { self.is_internal = true; self } + + /// Effective conversation scope, falling back to thread_id for legacy callers. + pub fn conversation_scope(&self) -> Option<&str> { + self.conversation_scope_id + .as_deref() + .or(self.thread_id.as_deref()) + } + + /// Best-effort routing target for proactive replies on the current channel. + pub fn routing_target(&self) -> Option { + routing_target_from_metadata(&self.metadata).or_else(|| { + if self.sender_id.is_empty() { + None + } else { + Some(self.sender_id.clone()) + } + }) + } +} + +/// Extract a channel-specific proactive routing target from message metadata. +pub fn routing_target_from_metadata(metadata: &serde_json::Value) -> Option { + metadata + .get("signal_target") + .and_then(|value| match value { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Number(n) => Some(n.to_string()), + _ => None, + }) + .or_else(|| { + metadata.get("chat_id").and_then(|value| match value { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Number(n) => Some(n.to_string()), + _ => None, + }) + }) + .or_else(|| { + metadata.get("target").and_then(|value| match value { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Number(n) => Some(n.to_string()), + _ => None, + }) + }) } /// Stream of incoming messages. diff --git a/src/channels/http.rs b/src/channels/http.rs index 5c173bf299..9f39f46e00 100644 --- a/src/channels/http.rs +++ b/src/channels/http.rs @@ -133,7 +133,8 @@ impl HttpChannel { #[derive(Debug, Deserialize)] struct WebhookRequest { - /// User or client identifier (ignored, user is fixed by server config). + /// Optional caller or client identifier for sender-scoped routing. + /// The channel owner/storage scope remains fixed by server config. #[serde(default)] user_id: Option, /// Message content. @@ -403,12 +404,38 @@ async fn process_authenticated_request( state: Arc, req: WebhookRequest, ) -> axum::response::Response { - let _ = req.user_id.as_ref().map(|user_id| { - tracing::debug!( - provided_user_id = %user_id, - "HTTP webhook request provided user_id, ignoring in favor of configured user_id" - ); - }); + let normalized_user_id = req + .user_id + .as_deref() + .map(str::trim) + .filter(|user_id| !user_id.is_empty()); + + match (req.user_id.as_deref(), normalized_user_id) { + (Some(raw_user_id), Some(user_id)) if raw_user_id != user_id => { + tracing::debug!( + provided_user_id = %raw_user_id, + normalized_sender_id = %user_id, + configured_owner_id = %state.user_id, + "HTTP webhook request provided user_id; trimming and using it as sender_id while keeping the configured owner scope" + ); + } + (Some(user_id), Some(_)) => { + tracing::debug!( + provided_user_id = %user_id, + configured_owner_id = %state.user_id, + "HTTP webhook request provided user_id; using it as sender_id while keeping the configured owner scope" + ); + } + (Some(raw_user_id), None) => { + tracing::debug!( + provided_user_id = %raw_user_id, + configured_owner_id = %state.user_id, + "HTTP webhook request provided a blank user_id; falling back to the configured owner scope for sender_id" + ); + } + (None, None) => {} + (None, Some(_)) => unreachable!("normalized user_id requires a raw user_id"), + } if req.content.len() > MAX_CONTENT_BYTES { return ( @@ -514,11 +541,13 @@ async fn process_authenticated_request( Vec::new() }; - let mut msg = IncomingMessage::new("http", &state.user_id, &req.content).with_metadata( - serde_json::json!({ + let sender_id = normalized_user_id.unwrap_or(&state.user_id).to_string(); + let mut msg = IncomingMessage::new("http", &state.user_id, &req.content) + .with_owner_id(&state.user_id) + .with_sender_id(sender_id) + .with_metadata(serde_json::json!({ "wait_for_response": wait_for_response, - }), - ); + })); if !attachments.is_empty() { msg = msg.with_attachments(attachments); @@ -682,6 +711,7 @@ mod tests { use axum::body::Body; use axum::http::{HeaderValue, Request}; use secrecy::SecretString; + use tokio_stream::StreamExt; use tower::ServiceExt; use super::*; @@ -820,6 +850,70 @@ mod tests { assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); } + #[tokio::test] + async fn webhook_blank_user_id_falls_back_to_owner_scope() { + let secret = "test-secret-123"; + let channel = test_channel(Some(secret)); + let mut stream = channel.start().await.unwrap(); + let app = channel.routes(); + + let body = serde_json::json!({ + "content": "hello", + "user_id": " " + }); + let body_bytes = serde_json::to_vec(&body).unwrap(); + let signature = compute_signature(secret, &body_bytes); + let req = Request::builder() + .method("POST") + .uri("/webhook") + .header("content-type", "application/json") + .header("x-hub-signature-256", signature) + .body(Body::from(body_bytes)) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .expect("timed out waiting for webhook message") + .expect("stream should yield a webhook message"); + assert_eq!(msg.sender_id, "http"); + assert_eq!(msg.owner_id, "http"); + } + + #[tokio::test] + async fn webhook_user_id_is_trimmed_before_becoming_sender_id() { + let secret = "test-secret-123"; + let channel = test_channel(Some(secret)); + let mut stream = channel.start().await.unwrap(); + let app = channel.routes(); + + let body = serde_json::json!({ + "content": "hello", + "user_id": " alice " + }); + let body_bytes = serde_json::to_vec(&body).unwrap(); + let signature = compute_signature(secret, &body_bytes); + let req = Request::builder() + .method("POST") + .uri("/webhook") + .header("content-type", "application/json") + .header("x-hub-signature-256", signature) + .body(Body::from(body_bytes)) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .expect("timed out waiting for webhook message") + .expect("stream should yield a webhook message"); + assert_eq!(msg.sender_id, "alice"); + assert_eq!(msg.owner_id, "http"); + } + /// Regression test for issue #869: RwLock read guard was held across /// tx.send(msg).await in `process_message()`, blocking shutdown() from /// acquiring the write lock when the channel buffer was full. diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 289b64c7be..c023069293 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -39,7 +39,7 @@ mod webhook_server; pub use channel::{ AttachmentKind, Channel, ChannelSecretUpdater, IncomingAttachment, IncomingMessage, - MessageStream, OutgoingResponse, StatusUpdate, + MessageStream, OutgoingResponse, StatusUpdate, routing_target_from_metadata, }; pub use http::{HttpChannel, HttpChannelState}; pub use manager::ChannelManager; diff --git a/src/channels/repl.rs b/src/channels/repl.rs index 230d5e92c2..40d669198c 100644 --- a/src/channels/repl.rs +++ b/src/channels/repl.rs @@ -200,6 +200,8 @@ fn format_json_params(params: &serde_json::Value, indent: &str) -> String { /// REPL channel with line editing and markdown rendering. pub struct ReplChannel { + /// Stable owner scope for this REPL instance. + user_id: String, /// Optional single message to send (for -m flag). single_message: Option, /// Debug mode flag (shared with input thread). @@ -213,7 +215,13 @@ pub struct ReplChannel { impl ReplChannel { /// Create a new REPL channel. pub fn new() -> Self { + Self::with_user_id("default") + } + + /// Create a new REPL channel for a specific owner scope. + pub fn with_user_id(user_id: impl Into) -> Self { Self { + user_id: user_id.into(), single_message: None, debug_mode: Arc::new(AtomicBool::new(false)), is_streaming: Arc::new(AtomicBool::new(false)), @@ -223,7 +231,13 @@ impl ReplChannel { /// Create a REPL channel that sends a single message and exits. pub fn with_message(message: String) -> Self { + Self::with_message_for_user("default", message) + } + + /// Create a REPL channel that sends a single message for a specific owner scope and exits. + pub fn with_message_for_user(user_id: impl Into, message: String) -> Self { Self { + user_id: user_id.into(), single_message: Some(message), debug_mode: Arc::new(AtomicBool::new(false)), is_streaming: Arc::new(AtomicBool::new(false)), @@ -292,6 +306,7 @@ impl Channel for ReplChannel { async fn start(&self) -> Result { let (tx, rx) = mpsc::channel(32); let single_message = self.single_message.clone(); + let user_id = self.user_id.clone(); let debug_mode = Arc::clone(&self.debug_mode); let suppress_banner = Arc::clone(&self.suppress_banner); let esc_interrupt_triggered_for_thread = Arc::new(AtomicBool::new(false)); @@ -301,11 +316,11 @@ impl Channel for ReplChannel { // Single message mode: send it and return if let Some(msg) = single_message { - let incoming = IncomingMessage::new("repl", "default", &msg).with_timezone(&sys_tz); + let incoming = IncomingMessage::new("repl", &user_id, &msg).with_timezone(&sys_tz); let _ = tx.blocking_send(incoming); // Ensure the agent exits after handling exactly one turn in -m mode, // even when other channels (gateway/http) are enabled. - let _ = tx.blocking_send(IncomingMessage::new("repl", "default", "/quit")); + let _ = tx.blocking_send(IncomingMessage::new("repl", &user_id, "/quit")); return; } @@ -366,7 +381,7 @@ impl Channel for ReplChannel { "/quit" | "/exit" => { // Forward shutdown command so the agent loop exits even // when other channels (e.g. web gateway) are still active. - let msg = IncomingMessage::new("repl", "default", "/quit") + let msg = IncomingMessage::new("repl", &user_id, "/quit") .with_timezone(&sys_tz); let _ = tx.blocking_send(msg); break; @@ -389,7 +404,7 @@ impl Channel for ReplChannel { } let msg = - IncomingMessage::new("repl", "default", line).with_timezone(&sys_tz); + IncomingMessage::new("repl", &user_id, line).with_timezone(&sys_tz); if tx.blocking_send(msg).is_err() { break; } @@ -397,14 +412,14 @@ impl Channel for ReplChannel { Err(ReadlineError::Interrupted) => { if esc_interrupt_triggered_for_thread.swap(false, Ordering::Relaxed) { // Esc: interrupt current operation and keep REPL open. - let msg = IncomingMessage::new("repl", "default", "/interrupt") + let msg = IncomingMessage::new("repl", &user_id, "/interrupt") .with_timezone(&sys_tz); if tx.blocking_send(msg).is_err() { break; } } else { // Ctrl+C (VINTR): request graceful shutdown. - let msg = IncomingMessage::new("repl", "default", "/quit") + let msg = IncomingMessage::new("repl", &user_id, "/quit") .with_timezone(&sys_tz); let _ = tx.blocking_send(msg); break; @@ -416,7 +431,7 @@ impl Channel for ReplChannel { // immediately — just drop the REPL thread silently so other // channels (gateway, telegram, …) keep running. if std::io::stdin().is_terminal() { - let msg = IncomingMessage::new("repl", "default", "/quit") + let msg = IncomingMessage::new("repl", &user_id, "/quit") .with_timezone(&sys_tz); let _ = tx.blocking_send(msg); } diff --git a/src/channels/wasm/loader.rs b/src/channels/wasm/loader.rs index c261193e7d..6329428fea 100644 --- a/src/channels/wasm/loader.rs +++ b/src/channels/wasm/loader.rs @@ -27,6 +27,7 @@ pub struct WasmChannelLoader { pairing_store: Arc, settings_store: Option>, secrets_store: Option>, + owner_scope_id: String, } impl WasmChannelLoader { @@ -35,12 +36,14 @@ impl WasmChannelLoader { runtime: Arc, pairing_store: Arc, settings_store: Option>, + owner_scope_id: impl Into, ) -> Self { Self { runtime, pairing_store, settings_store, secrets_store: None, + owner_scope_id: owner_scope_id.into(), } } @@ -149,6 +152,7 @@ impl WasmChannelLoader { self.runtime.clone(), prepared, capabilities, + self.owner_scope_id.clone(), config_json, self.pairing_store.clone(), self.settings_store.clone(), @@ -487,7 +491,8 @@ mod tests { async fn test_loader_invalid_name() { let config = WasmChannelRuntimeConfig::for_testing(); let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap()); - let loader = WasmChannelLoader::new(runtime, Arc::new(PairingStore::new()), None); + let loader = + WasmChannelLoader::new(runtime, Arc::new(PairingStore::new()), None, "default"); let dir = TempDir::new().unwrap(); let wasm_path = dir.path().join("test.wasm"); @@ -505,7 +510,8 @@ mod tests { async fn load_from_dir_returns_empty_when_dir_missing() { let config = WasmChannelRuntimeConfig::for_testing(); let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap()); - let loader = WasmChannelLoader::new(runtime, Arc::new(PairingStore::new()), None); + let loader = + WasmChannelLoader::new(runtime, Arc::new(PairingStore::new()), None, "default"); let dir = TempDir::new().unwrap(); let missing = dir.path().join("nonexistent_channels_dir"); diff --git a/src/channels/wasm/mod.rs b/src/channels/wasm/mod.rs index dba843417d..882709a967 100644 --- a/src/channels/wasm/mod.rs +++ b/src/channels/wasm/mod.rs @@ -69,7 +69,7 @@ //! let runtime = WasmChannelRuntime::new(config)?; //! //! // Load channels from directory -//! let loader = WasmChannelLoader::new(runtime); +//! let loader = WasmChannelLoader::new(runtime, pairing_store, settings_store, owner_scope_id); //! let channels = loader.load_from_dir(Path::new("~/.ironclaw/channels/")).await?; //! //! // Add to channel manager diff --git a/src/channels/wasm/router.rs b/src/channels/wasm/router.rs index 9b0f3da176..8005ccea56 100644 --- a/src/channels/wasm/router.rs +++ b/src/channels/wasm/router.rs @@ -672,6 +672,7 @@ mod tests { runtime, prepared, capabilities, + "default", "{}".to_string(), Arc::new(PairingStore::new()), None, diff --git a/src/channels/wasm/setup.rs b/src/channels/wasm/setup.rs index 9c0c3f33a4..2b9703dc6f 100644 --- a/src/channels/wasm/setup.rs +++ b/src/channels/wasm/setup.rs @@ -50,6 +50,7 @@ pub async fn setup_wasm_channels( Arc::clone(&runtime), Arc::clone(&pairing_store), settings_store.clone(), + config.owner_id.clone(), ); if let Some(secrets) = secrets_store { loader = loader.with_secrets_store(Arc::clone(secrets)); @@ -117,6 +118,11 @@ async fn register_channel( ) -> (String, Box) { let channel_name = loaded.name().to_string(); tracing::info!("Loaded WASM channel: {}", channel_name); + let owner_actor_id = config + .channels + .wasm_channel_owner_ids + .get(channel_name.as_str()) + .map(ToString::to_string); let secret_name = loaded.webhook_secret_name(); let sig_key_secret_name = loaded.signature_key_secret_name(); @@ -124,7 +130,7 @@ async fn register_channel( let webhook_secret = if let Some(secrets) = secrets_store { secrets - .get_decrypted("default", &secret_name) + .get_decrypted(&config.owner_id, &secret_name) .await .ok() .map(|s| s.expose().to_string()) @@ -142,7 +148,7 @@ async fn register_channel( require_secret: webhook_secret.is_some(), }]; - let channel_arc = Arc::new(loaded.channel); + let channel_arc = Arc::new(loaded.channel.with_owner_actor_id(owner_actor_id.clone())); // Inject runtime config (tunnel URL, webhook secret, owner_id). { @@ -216,7 +222,7 @@ async fn register_channel( // Register Ed25519 signature key if declared in capabilities. if let Some(ref sig_key_name) = sig_key_secret_name && let Some(secrets) = secrets_store - && let Ok(key_secret) = secrets.get_decrypted("default", sig_key_name).await + && let Ok(key_secret) = secrets.get_decrypted(&config.owner_id, sig_key_name).await { match wasm_router .register_signature_key(&channel_name, key_secret.expose()) @@ -234,7 +240,9 @@ async fn register_channel( // Register HMAC signing secret if declared in capabilities. if let Some(ref hmac_secret_name) = hmac_secret_name && let Some(secrets) = secrets_store - && let Ok(secret) = secrets.get_decrypted("default", hmac_secret_name).await + && let Ok(secret) = secrets + .get_decrypted(&config.owner_id, hmac_secret_name) + .await { wasm_router .register_hmac_secret(&channel_name, secret.expose()) @@ -249,6 +257,7 @@ async fn register_channel( .as_ref() .map(|s| s.as_ref() as &dyn SecretsStore), &channel_name, + &config.owner_id, ) .await { @@ -286,6 +295,7 @@ pub async fn inject_channel_credentials( channel: &Arc, secrets: Option<&dyn SecretsStore>, channel_name: &str, + owner_id: &str, ) -> anyhow::Result { if channel_name.trim().is_empty() { return Ok(0); @@ -297,7 +307,7 @@ pub async fn inject_channel_credentials( // 1. Try injecting from persistent secrets store if available if let Some(secrets) = secrets { let all_secrets = secrets - .list("default") + .list(owner_id) .await .map_err(|e| anyhow::anyhow!("Failed to list secrets: {}", e))?; @@ -308,7 +318,7 @@ pub async fn inject_channel_credentials( continue; } - let decrypted = match secrets.get_decrypted("default", &secret_meta.name).await { + let decrypted = match secrets.get_decrypted(owner_id, &secret_meta.name).await { Ok(d) => d, Err(e) => { tracing::warn!( diff --git a/src/channels/wasm/wrapper.rs b/src/channels/wasm/wrapper.rs index 1529da41b4..0be8756b1a 100644 --- a/src/channels/wasm/wrapper.rs +++ b/src/channels/wasm/wrapper.rs @@ -709,6 +709,12 @@ pub struct WasmChannel { /// Settings store for persisting broadcast metadata across restarts. settings_store: Option>, + /// Stable owner scope for persistent data and owner-target routing. + owner_scope_id: String, + + /// Channel-specific actor ID that maps to the instance owner on this channel. + owner_actor_id: Option, + /// Secrets store for host-based credential injection. /// Used to pre-resolve credentials before each WASM callback. secrets_store: Option>, @@ -719,6 +725,7 @@ pub struct WasmChannel { /// method and the static polling helper share one implementation. async fn do_update_broadcast_metadata( channel_name: &str, + owner_scope_id: &str, metadata: &str, last_broadcast_metadata: &tokio::sync::RwLock>, settings_store: Option<&Arc>, @@ -731,7 +738,7 @@ async fn do_update_broadcast_metadata( if changed && let Some(store) = settings_store { let key = format!("channel_broadcast_metadata_{}", channel_name); let value = serde_json::Value::String(metadata.to_string()); - if let Err(e) = store.set_setting("default", &key, &value).await { + if let Err(e) = store.set_setting(owner_scope_id, &key, &value).await { tracing::warn!( channel = %channel_name, "Failed to persist broadcast metadata: {}", @@ -741,12 +748,70 @@ async fn do_update_broadcast_metadata( } } +fn resolve_message_scope( + owner_scope_id: &str, + owner_actor_id: Option<&str>, + sender_id: &str, +) -> (String, bool) { + if owner_actor_id.is_some_and(|owner_actor_id| owner_actor_id == sender_id) { + (owner_scope_id.to_string(), true) + } else { + (sender_id.to_string(), false) + } +} + +fn uses_owner_broadcast_target(user_id: &str, owner_scope_id: &str) -> bool { + user_id == owner_scope_id +} + +fn missing_routing_target_error(name: &str, reason: String) -> ChannelError { + ChannelError::MissingRoutingTarget { + name: name.to_string(), + reason, + } +} + +fn resolve_owner_broadcast_target( + channel_name: &str, + metadata: &str, +) -> Result { + let metadata: serde_json::Value = serde_json::from_str(metadata).map_err(|e| { + missing_routing_target_error( + channel_name, + format!("Invalid stored owner routing metadata: {e}"), + ) + })?; + + crate::channels::routing_target_from_metadata(&metadata).ok_or_else(|| { + missing_routing_target_error( + channel_name, + format!( + "Stored owner routing metadata for channel '{}' is missing a delivery target.", + channel_name + ), + ) + }) +} + +fn apply_emitted_metadata(mut msg: IncomingMessage, metadata_json: &str) -> IncomingMessage { + if let Ok(metadata) = serde_json::from_str(metadata_json) { + msg = msg.with_metadata(metadata); + if msg.conversation_scope().is_none() + && let Some(scope_id) = crate::channels::routing_target_from_metadata(&msg.metadata) + { + msg = msg.with_conversation_scope(scope_id); + } + } + msg +} + impl WasmChannel { /// Create a new WASM channel. pub fn new( runtime: Arc, prepared: Arc, capabilities: ChannelCapabilities, + owner_scope_id: impl Into, config_json: String, pairing_store: Arc, settings_store: Option>, @@ -773,6 +838,8 @@ impl WasmChannel { workspace_store: Arc::new(ChannelWorkspaceStore::new()), last_broadcast_metadata: Arc::new(tokio::sync::RwLock::new(None)), settings_store, + owner_scope_id: owner_scope_id.into(), + owner_actor_id: None, secrets_store: None, } } @@ -787,6 +854,12 @@ impl WasmChannel { self } + /// Bind this channel to the external actor that maps to the configured owner. + pub fn with_owner_actor_id(mut self, owner_actor_id: Option) -> Self { + self.owner_actor_id = owner_actor_id; + self + } + /// Update the channel config before starting. /// /// Merges the provided values into the existing config JSON. @@ -843,6 +916,7 @@ impl WasmChannel { async fn update_broadcast_metadata(&self, metadata: &str) { do_update_broadcast_metadata( &self.name, + &self.owner_scope_id, metadata, &self.last_broadcast_metadata, self.settings_store.as_ref(), @@ -854,7 +928,7 @@ impl WasmChannel { async fn load_broadcast_metadata(&self) { if let Some(ref store) = self.settings_store { match store - .get_setting("default", &self.broadcast_metadata_key()) + .get_setting(&self.owner_scope_id, &self.broadcast_metadata_key()) .await { Ok(Some(serde_json::Value::String(meta))) => { @@ -864,7 +938,30 @@ impl WasmChannel { "Restored broadcast metadata from settings" ); } - Ok(_) => {} + Ok(_) => { + if self.owner_scope_id != "default" { + match store + .get_setting("default", &self.broadcast_metadata_key()) + .await + { + Ok(Some(serde_json::Value::String(meta))) => { + *self.last_broadcast_metadata.write().await = Some(meta); + tracing::debug!( + channel = %self.name, + "Restored legacy owner broadcast metadata from default scope" + ); + } + Ok(_) => {} + Err(e) => { + tracing::warn!( + channel = %self.name, + "Failed to load legacy broadcast metadata: {}", + e + ); + } + } + } + } Err(e) => { tracing::warn!( channel = %self.name, @@ -1064,9 +1161,12 @@ impl WasmChannel { let timeout = self.runtime.config().callback_timeout; let channel_name = self.name.clone(); let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); let workspace_store = self.workspace_store.clone(); @@ -1204,9 +1304,12 @@ impl WasmChannel { let capabilities = Self::inject_workspace_reader(&self.capabilities, &self.workspace_store); let timeout = self.runtime.config().callback_timeout; let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); let workspace_store = self.workspace_store.clone(); @@ -1307,9 +1410,12 @@ impl WasmChannel { let timeout = self.runtime.config().callback_timeout; let channel_name = self.name.clone(); let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); let workspace_store = self.workspace_store.clone(); @@ -1414,9 +1520,12 @@ impl WasmChannel { let timeout = self.runtime.config().callback_timeout; let channel_name = self.name.clone(); let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); // Prepare response data @@ -1555,9 +1664,12 @@ impl WasmChannel { let timeout = self.runtime.config().callback_timeout; let channel_name = self.name.clone(); let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); let user_id = user_id.to_string(); @@ -1659,9 +1771,12 @@ impl WasmChannel { let timeout = self.runtime.config().callback_timeout; let channel_name = self.name.clone(); let credentials = self.get_credentials().await; - let host_credentials = - resolve_channel_host_credentials(&self.capabilities, self.secrets_store.as_deref()) - .await; + let host_credentials = resolve_channel_host_credentials( + &self.capabilities, + self.secrets_store.as_deref(), + &self.owner_scope_id, + ) + .await; let pairing_store = self.pairing_store.clone(); let Some(wit_update) = status_to_wit(status, metadata) else { @@ -1831,6 +1946,7 @@ impl WasmChannel { let repeater_host_credentials = resolve_channel_host_credentials( &self.capabilities, self.secrets_store.as_deref(), + &self.owner_scope_id, ) .await; let pairing_store = self.pairing_store.clone(); @@ -2027,8 +2143,16 @@ impl WasmChannel { } } + let (resolved_user_id, is_owner_sender) = resolve_message_scope( + &self.owner_scope_id, + self.owner_actor_id.as_deref(), + &emitted.user_id, + ); + // Convert to IncomingMessage - let mut msg = IncomingMessage::new(&self.name, &emitted.user_id, &emitted.content); + let mut msg = IncomingMessage::new(&self.name, &resolved_user_id, &emitted.content) + .with_owner_id(&self.owner_scope_id) + .with_sender_id(&emitted.user_id); if let Some(name) = emitted.user_name { msg = msg.with_user_name(name); @@ -2060,9 +2184,9 @@ impl WasmChannel { } // Parse metadata JSON - if let Ok(metadata) = serde_json::from_str(&emitted.metadata_json) { - msg = msg.with_metadata(metadata); - // Store for broadcast routing (chat_id etc.) + msg = apply_emitted_metadata(msg, &emitted.metadata_json); + if is_owner_sender { + // Store for owner-target routing (chat_id etc.). self.update_broadcast_metadata(&emitted.metadata_json).await; } @@ -2112,6 +2236,8 @@ impl WasmChannel { let last_broadcast_metadata = self.last_broadcast_metadata.clone(); let settings_store = self.settings_store.clone(); let poll_secrets_store = self.secrets_store.clone(); + let owner_scope_id = self.owner_scope_id.clone(); + let owner_actor_id = self.owner_actor_id.clone(); tokio::spawn(async move { let mut interval_timer = tokio::time::interval(interval); @@ -2129,6 +2255,7 @@ impl WasmChannel { let host_credentials = resolve_channel_host_credentials( &poll_capabilities, poll_secrets_store.as_deref(), + &owner_scope_id, ) .await; @@ -2150,12 +2277,16 @@ impl WasmChannel { // Process any emitted messages if !emitted_messages.is_empty() && let Err(e) = Self::dispatch_emitted_messages( - &channel_name, + EmitDispatchContext { + channel_name: &channel_name, + owner_scope_id: &owner_scope_id, + owner_actor_id: owner_actor_id.as_deref(), + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: settings_store.as_ref(), + }, emitted_messages, - &message_tx, - &rate_limiter, - &last_broadcast_metadata, - settings_store.as_ref(), ).await { tracing::warn!( channel = %channel_name, @@ -2277,25 +2408,21 @@ impl WasmChannel { /// This is a static helper used by the polling loop since it doesn't have /// access to `&self`. async fn dispatch_emitted_messages( - channel_name: &str, + dispatch: EmitDispatchContext<'_>, messages: Vec, - message_tx: &RwLock>>, - rate_limiter: &RwLock, - last_broadcast_metadata: &tokio::sync::RwLock>, - settings_store: Option<&Arc>, ) -> Result<(), WasmChannelError> { tracing::info!( - channel = %channel_name, + channel = %dispatch.channel_name, message_count = messages.len(), "Processing emitted messages from polling callback" ); // Clone sender to avoid holding RwLock read guard across send().await in the loop let tx = { - let tx_guard = message_tx.read().await; + let tx_guard = dispatch.message_tx.read().await; let Some(tx) = tx_guard.as_ref() else { tracing::error!( - channel = %channel_name, + channel = %dispatch.channel_name, count = messages.len(), "Messages emitted but no sender available - channel may not be started!" ); @@ -2307,20 +2434,29 @@ impl WasmChannel { for emitted in messages { // Check rate limit — acquire and release the write lock before send().await { - let mut limiter = rate_limiter.write().await; + let mut limiter = dispatch.rate_limiter.write().await; if !limiter.check_and_record() { tracing::warn!( - channel = %channel_name, + channel = %dispatch.channel_name, "Message emission rate limited" ); return Err(WasmChannelError::EmitRateLimited { - name: channel_name.to_string(), + name: dispatch.channel_name.to_string(), }); } } + let (resolved_user_id, is_owner_sender) = resolve_message_scope( + dispatch.owner_scope_id, + dispatch.owner_actor_id, + &emitted.user_id, + ); + // Convert to IncomingMessage - let mut msg = IncomingMessage::new(channel_name, &emitted.user_id, &emitted.content); + let mut msg = + IncomingMessage::new(dispatch.channel_name, &resolved_user_id, &emitted.content) + .with_owner_id(dispatch.owner_scope_id) + .with_sender_id(&emitted.user_id); if let Some(name) = emitted.user_name { msg = msg.with_user_name(name); @@ -2351,22 +2487,22 @@ impl WasmChannel { msg = msg.with_attachments(incoming_attachments); } - // Parse metadata JSON - if let Ok(metadata) = serde_json::from_str(&emitted.metadata_json) { - msg = msg.with_metadata(metadata); - // Store for broadcast routing (chat_id etc.) + msg = apply_emitted_metadata(msg, &emitted.metadata_json); + if is_owner_sender { + // Store for owner-target routing (chat_id etc.) do_update_broadcast_metadata( - channel_name, + dispatch.channel_name, + dispatch.owner_scope_id, &emitted.metadata_json, - last_broadcast_metadata, - settings_store, + dispatch.last_broadcast_metadata, + dispatch.settings_store, ) .await; } // Send to stream — no locks held across this await tracing::info!( - channel = %channel_name, + channel = %dispatch.channel_name, user_id = %emitted.user_id, content_len = emitted.content.len(), attachment_count = msg.attachments.len(), @@ -2375,14 +2511,14 @@ impl WasmChannel { if tx.send(msg).await.is_err() { tracing::error!( - channel = %channel_name, + channel = %dispatch.channel_name, "Failed to send polled message, channel closed" ); break; } tracing::info!( - channel = %channel_name, + channel = %dispatch.channel_name, "Message successfully sent to agent queue" ); } @@ -2391,6 +2527,16 @@ impl WasmChannel { } } +struct EmitDispatchContext<'a> { + channel_name: &'a str, + owner_scope_id: &'a str, + owner_actor_id: Option<&'a str>, + message_tx: &'a RwLock>>, + rate_limiter: &'a RwLock, + last_broadcast_metadata: &'a tokio::sync::RwLock>, + settings_store: Option<&'a Arc>, +} + #[async_trait] impl Channel for WasmChannel { fn name(&self) -> &str { @@ -2490,8 +2636,11 @@ impl Channel for WasmChannel { // The original metadata contains channel-specific routing info (e.g., Telegram chat_id) // that the WASM channel needs to send the reply to the correct destination. let metadata_json = serde_json::to_string(&msg.metadata).unwrap_or_default(); - // Store for broadcast routing (chat_id etc.) - self.update_broadcast_metadata(&metadata_json).await; + // Store for owner-target routing (chat_id etc.) only when the configured + // owner is the actor in this conversation. + if msg.user_id == self.owner_scope_id { + self.update_broadcast_metadata(&metadata_json).await; + } self.call_on_respond( msg.id, &response.content, @@ -2514,8 +2663,24 @@ impl Channel for WasmChannel { response: OutgoingResponse, ) -> Result<(), ChannelError> { self.cancel_typing_task().await; + let resolved_target = if uses_owner_broadcast_target(user_id, &self.owner_scope_id) { + let metadata = self.last_broadcast_metadata.read().await.clone().ok_or_else(|| { + missing_routing_target_error( + &self.name, + format!( + "No stored owner routing target for channel '{}'. Send a message from the owner on this channel first.", + self.name + ), + ) + })?; + + resolve_owner_broadcast_target(&self.name, &metadata)? + } else { + user_id.to_string() + }; + self.call_on_broadcast( - user_id, + &resolved_target, &response.content, response.thread_id.as_deref(), &response.attachments, @@ -2931,6 +3096,7 @@ fn extract_host_from_url(url: &str) -> Option { async fn resolve_channel_host_credentials( capabilities: &ChannelCapabilities, store: Option<&(dyn SecretsStore + Send + Sync)>, + owner_scope_id: &str, ) -> Vec { let store = match store { Some(s) => s, @@ -2957,7 +3123,10 @@ async fn resolve_channel_host_credentials( continue; } - let secret = match store.get_decrypted("default", &mapping.secret_name).await { + let secret = match store + .get_decrypted(owner_scope_id, &mapping.secret_name) + .await + { Ok(s) => s, Err(e) => { tracing::debug!( @@ -3076,12 +3245,18 @@ mod tests { use crate::channels::wasm::runtime::{ PreparedChannelModule, WasmChannelRuntime, WasmChannelRuntimeConfig, }; - use crate::channels::wasm::wrapper::{HttpResponse, WasmChannel}; + use crate::channels::wasm::wrapper::{ + EmitDispatchContext, HttpResponse, WasmChannel, uses_owner_broadcast_target, + }; use crate::pairing::PairingStore; use crate::testing::credentials::TEST_TELEGRAM_BOT_TOKEN; use crate::tools::wasm::ResourceLimits; fn create_test_channel() -> WasmChannel { + create_test_channel_with_owner_scope("default") + } + + fn create_test_channel_with_owner_scope(owner_scope_id: &str) -> WasmChannel { let config = WasmChannelRuntimeConfig::for_testing(); let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap()); @@ -3098,6 +3273,7 @@ mod tests { runtime, prepared, capabilities, + owner_scope_id, "{}".to_string(), Arc::new(PairingStore::new()), None, @@ -3185,7 +3361,7 @@ mod tests { ) .await; - assert!(result.is_ok()); + assert!(result.is_ok()); // safety: test-only assertion assert!(result.unwrap().is_empty()); } @@ -3209,28 +3385,32 @@ mod tests { let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); let result = WasmChannel::dispatch_emitted_messages( - "test-channel", + EmitDispatchContext { + channel_name: "test-channel", + owner_scope_id: "default", + owner_actor_id: None, + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, messages, - &message_tx, - &rate_limiter, - &last_broadcast_metadata, - None, ) .await; - assert!(result.is_ok()); + assert!(result.is_ok()); // safety: test-only assertion // Verify messages were sent - let msg1 = rx.try_recv().expect("Should receive first message"); - assert_eq!(msg1.user_id, "user1"); - assert_eq!(msg1.content, "Hello from polling!"); + let msg1 = rx.try_recv().expect("Should receive first message"); // safety: test-only assertion + assert_eq!(msg1.user_id, "user1"); // safety: test-only assertion + assert_eq!(msg1.content, "Hello from polling!"); // safety: test-only assertion - let msg2 = rx.try_recv().expect("Should receive second message"); - assert_eq!(msg2.user_id, "user2"); - assert_eq!(msg2.content, "Another message"); + let msg2 = rx.try_recv().expect("Should receive second message"); // safety: test-only assertion + assert_eq!(msg2.user_id, "user2"); // safety: test-only assertion + assert_eq!(msg2.content, "Another message"); // safety: test-only assertion // No more messages - assert!(rx.try_recv().is_err()); + assert!(rx.try_recv().is_err()); // safety: test-only assertion } #[tokio::test] @@ -3250,12 +3430,16 @@ mod tests { // Should return Ok even without a sender (logs warning but doesn't fail) let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); let result = WasmChannel::dispatch_emitted_messages( - "test-channel", + EmitDispatchContext { + channel_name: "test-channel", + owner_scope_id: "default", + owner_actor_id: None, + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, messages, - &message_tx, - &rate_limiter, - &last_broadcast_metadata, - None, ) .await; @@ -3284,6 +3468,7 @@ mod tests { runtime, prepared, capabilities, + "default", "{}".to_string(), Arc::new(PairingStore::new()), None, @@ -4255,42 +4440,172 @@ mod tests { let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); let result = WasmChannel::dispatch_emitted_messages( - "test-channel", + EmitDispatchContext { + channel_name: "test-channel", + owner_scope_id: "default", + owner_actor_id: None, + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, messages, - &message_tx, - &rate_limiter, - &last_broadcast_metadata, - None, ) .await; - assert!(result.is_ok()); + assert!(result.is_ok()); // safety: test-only assertion - let msg = rx.try_recv().expect("Should receive message"); - assert_eq!(msg.content, "Check these files"); - assert_eq!(msg.attachments.len(), 2); + let msg = rx.try_recv().expect("Should receive message"); // safety: test-only assertion + assert_eq!(msg.content, "Check these files"); // safety: test-only assertion + assert_eq!(msg.attachments.len(), 2); // safety: test-only assertion // Verify first attachment - assert_eq!(msg.attachments[0].id, "photo123"); - assert_eq!(msg.attachments[0].mime_type, "image/jpeg"); - assert_eq!(msg.attachments[0].filename, Some("cat.jpg".to_string())); - assert_eq!(msg.attachments[0].size_bytes, Some(50_000)); + assert_eq!(msg.attachments[0].id, "photo123"); // safety: test-only assertion + assert_eq!(msg.attachments[0].mime_type, "image/jpeg"); // safety: test-only assertion + assert_eq!(msg.attachments[0].filename, Some("cat.jpg".to_string())); // safety: test-only assertion + assert_eq!(msg.attachments[0].size_bytes, Some(50_000)); // safety: test-only assertion assert_eq!( msg.attachments[0].source_url, Some("https://api.telegram.org/file/photo123".to_string()) - ); + ); // safety: test-only assertion // Verify second attachment - assert_eq!(msg.attachments[1].id, "doc456"); - assert_eq!(msg.attachments[1].mime_type, "application/pdf"); + assert_eq!(msg.attachments[1].id, "doc456"); // safety: test-only assertion + assert_eq!(msg.attachments[1].mime_type, "application/pdf"); // safety: test-only assertion assert_eq!( msg.attachments[1].extracted_text, Some("Report contents...".to_string()) - ); + ); // safety: test-only assertion assert_eq!( msg.attachments[1].storage_key, Some("store/doc456".to_string()) - ); + ); // safety: test-only assertion + } + + #[tokio::test] + async fn test_dispatch_emitted_messages_owner_binding_sets_owner_scope() { + use crate::channels::wasm::host::EmittedMessage; + + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + let message_tx = Arc::new(tokio::sync::RwLock::new(Some(tx))); + let rate_limiter = Arc::new(tokio::sync::RwLock::new( + crate::channels::wasm::host::ChannelEmitRateLimiter::new( + crate::channels::wasm::capabilities::EmitRateLimitConfig::default(), + ), + )); + let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); + + let messages = vec![ + EmittedMessage::new("telegram-owner", "Hello from owner") + .with_metadata(r#"{"chat_id":12345}"#), + ]; + + let result = WasmChannel::dispatch_emitted_messages( + EmitDispatchContext { + channel_name: "telegram", + owner_scope_id: "owner-scope", + owner_actor_id: Some("telegram-owner"), + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, + messages, + ) + .await; + + assert!(result.is_ok()); // safety: test-only assertion + + let msg = rx.try_recv().expect("Should receive message"); // safety: test-only assertion + assert_eq!(msg.user_id, "owner-scope"); // safety: test-only assertion + assert_eq!(msg.owner_id, "owner-scope"); // safety: test-only assertion + assert_eq!(msg.sender_id, "telegram-owner"); // safety: test-only assertion + assert_eq!(msg.conversation_scope(), Some("12345")); // safety: test-only assertion + let stored_metadata = last_broadcast_metadata.read().await.clone(); + assert_eq!(stored_metadata.as_deref(), Some(r#"{"chat_id":12345}"#)); // safety: test-only assertion + } + + #[tokio::test] + async fn test_dispatch_emitted_messages_guest_sender_stays_isolated() { + use crate::channels::wasm::host::EmittedMessage; + + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + let message_tx = Arc::new(tokio::sync::RwLock::new(Some(tx))); + let rate_limiter = Arc::new(tokio::sync::RwLock::new( + crate::channels::wasm::host::ChannelEmitRateLimiter::new( + crate::channels::wasm::capabilities::EmitRateLimitConfig::default(), + ), + )); + let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); + + let messages = vec![ + EmittedMessage::new("guest-42", "Hello from guest").with_metadata(r#"{"chat_id":999}"#), + ]; + + let result = WasmChannel::dispatch_emitted_messages( + EmitDispatchContext { + channel_name: "telegram", + owner_scope_id: "owner-scope", + owner_actor_id: Some("telegram-owner"), + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, + messages, + ) + .await; + + assert!(result.is_ok()); // safety: test-only assertion + + let msg = rx.try_recv().expect("Should receive message"); // safety: test-only assertion + assert_eq!(msg.user_id, "guest-42"); // safety: test-only assertion + assert_eq!(msg.owner_id, "owner-scope"); // safety: test-only assertion + assert_eq!(msg.sender_id, "guest-42"); // safety: test-only assertion + assert_eq!(msg.conversation_scope(), Some("999")); // safety: test-only assertion + assert!(last_broadcast_metadata.read().await.is_none()); // safety: test-only assertion + } + + #[tokio::test] + async fn test_broadcast_owner_scope_uses_stored_owner_metadata() { + let channel = create_test_channel_with_owner_scope("owner-scope") + .with_owner_actor_id(Some("telegram-owner".to_string())); + + *channel.last_broadcast_metadata.write().await = Some(r#"{"chat_id":12345}"#.to_string()); + + let result = channel + .broadcast( + "owner-scope", + crate::channels::OutgoingResponse::text("hello owner"), + ) + .await; + + assert!(result.is_ok()); // safety: test-only assertion + } + + #[test] + fn test_default_target_is_not_treated_as_owner_scope() { + assert!(!uses_owner_broadcast_target("default", "owner-scope")); // safety: test-only assertion + assert!(uses_owner_broadcast_target("default", "default")); // safety: test-only assertion + } + + #[tokio::test] + async fn test_broadcast_owner_scope_requires_stored_metadata() { + let channel = create_test_channel_with_owner_scope("owner-scope") + .with_owner_actor_id(Some("telegram-owner".to_string())); + + let result = channel + .broadcast( + "owner-scope", + crate::channels::OutgoingResponse::text("hello owner"), + ) + .await; + + assert!(result.is_err()); // safety: test-only assertion + let err = result.unwrap_err().to_string(); + let mentions_missing_owner_route = + err.contains("Send a message from the owner on this channel first"); + assert!(mentions_missing_owner_route); // safety: test-only assertion } #[tokio::test] @@ -4310,20 +4625,24 @@ mod tests { let last_broadcast_metadata = Arc::new(tokio::sync::RwLock::new(None)); let result = WasmChannel::dispatch_emitted_messages( - "test-channel", + EmitDispatchContext { + channel_name: "test-channel", + owner_scope_id: "default", + owner_actor_id: None, + message_tx: &message_tx, + rate_limiter: &rate_limiter, + last_broadcast_metadata: &last_broadcast_metadata, + settings_store: None, + }, messages, - &message_tx, - &rate_limiter, - &last_broadcast_metadata, - None, ) .await; - assert!(result.is_ok()); + assert!(result.is_ok()); // safety: test-only assertion - let msg = rx.try_recv().expect("Should receive message"); - assert_eq!(msg.content, "Just text, no attachments"); - assert!(msg.attachments.is_empty()); + let msg = rx.try_recv().expect("Should receive message"); // safety: test-only assertion + assert_eq!(msg.content, "Just text, no attachments"); // safety: test-only assertion + assert!(msg.attachments.is_empty()); // safety: test-only assertion } #[test] diff --git a/src/channels/web/server.rs b/src/channels/web/server.rs index fb8c93ae23..1eb49e3cf5 100644 --- a/src/channels/web/server.rs +++ b/src/channels/web/server.rs @@ -26,7 +26,6 @@ use tower_http::set_header::SetResponseHeaderLayer; use uuid::Uuid; use crate::agent::SessionManager; -use crate::agent::routine::{Trigger, next_cron_fire}; use crate::bootstrap::ironclaw_base_dir; use crate::channels::IncomingMessage; use crate::channels::relay::DEFAULT_RELAY_NAME; @@ -36,6 +35,7 @@ use crate::channels::web::handlers::jobs::{ jobs_events_handler, jobs_list_handler, jobs_prompt_handler, jobs_restart_handler, jobs_summary_handler, }; +use crate::channels::web::handlers::routines::{routines_delete_handler, routines_toggle_handler}; use crate::channels::web::handlers::skills::{ skills_install_handler, skills_list_handler, skills_remove_handler, skills_search_handler, }; @@ -2470,83 +2470,6 @@ async fn routines_trigger_handler( }))) } -#[derive(Deserialize)] -struct ToggleRequest { - enabled: Option, -} - -async fn routines_toggle_handler( - State(state): State>, - Path(id): Path, - body: Option>, -) -> Result, (StatusCode, String)> { - let store = state.store.as_ref().ok_or(( - StatusCode::SERVICE_UNAVAILABLE, - "Database not available".to_string(), - ))?; - - let routine_id = Uuid::parse_str(&id) - .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid routine ID".to_string()))?; - - let mut routine = store - .get_routine(routine_id) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? - .ok_or((StatusCode::NOT_FOUND, "Routine not found".to_string()))?; - - let was_enabled = routine.enabled; - // If a specific value was provided, use it; otherwise toggle. - routine.enabled = match body { - Some(Json(req)) => req.enabled.unwrap_or(!routine.enabled), - None => !routine.enabled, - }; - - if routine.enabled - && !was_enabled - && let Trigger::Cron { schedule, timezone } = &routine.trigger - { - routine.next_fire_at = next_cron_fire(schedule, timezone.as_deref()) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - } - - store - .update_routine(&routine) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - - Ok(Json(serde_json::json!({ - "status": if routine.enabled { "enabled" } else { "disabled" }, - "routine_id": routine_id, - }))) -} - -async fn routines_delete_handler( - State(state): State>, - Path(id): Path, -) -> Result, (StatusCode, String)> { - let store = state.store.as_ref().ok_or(( - StatusCode::SERVICE_UNAVAILABLE, - "Database not available".to_string(), - ))?; - - let routine_id = Uuid::parse_str(&id) - .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid routine ID".to_string()))?; - - let deleted = store - .delete_routine(routine_id) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - - if deleted { - Ok(Json(serde_json::json!({ - "status": "deleted", - "routine_id": routine_id, - }))) - } else { - Err((StatusCode::NOT_FOUND, "Routine not found".to_string())) - } -} - async fn routines_runs_handler( State(state): State>, Path(id): Path, diff --git a/src/cli/doctor.rs b/src/cli/doctor.rs index ee0b2be8b0..dfc04de767 100644 --- a/src/cli/doctor.rs +++ b/src/cli/doctor.rs @@ -405,10 +405,11 @@ fn check_routines_config() -> CheckResult { fn check_gateway_config(settings: &Settings) -> CheckResult { // Use the same resolve() path as runtime so invalid env values // (e.g. GATEWAY_PORT=abc) are caught here too. - let tunnel_enabled = crate::config::TunnelConfig::resolve(settings) - .map(|t| t.is_enabled()) - .unwrap_or(false); - match crate::config::ChannelsConfig::resolve(settings, tunnel_enabled) { + let owner_id = match crate::config::resolve_owner_id(settings) { + Ok(owner_id) => owner_id, + Err(e) => return CheckResult::Fail(format!("config error: {e}")), + }; + match crate::config::ChannelsConfig::resolve(settings, &owner_id) { Ok(channels) => match channels.gateway { Some(gw) => { if gw.auth_token.is_some() { diff --git a/src/cli/routines.rs b/src/cli/routines.rs index 852fc41fdd..dd8a2fa354 100644 --- a/src/cli/routines.rs +++ b/src/cli/routines.rs @@ -292,6 +292,16 @@ async fn list( // ── Create ────────────────────────────────────────────────── +fn cli_notify_config(notify_channel: Option) -> NotifyConfig { + NotifyConfig { + channel: notify_channel, + user: None, + on_attention: true, + on_failure: true, + on_success: false, + } +} + #[allow(clippy::too_many_arguments)] async fn create( db: &Arc, @@ -338,13 +348,7 @@ async fn create( max_concurrent: 1, dedup_window: None, }, - notify: NotifyConfig { - channel: notify_channel, - user: user_id.to_string(), - on_attention: true, - on_failure: true, - on_success: false, - }, + notify: cli_notify_config(notify_channel), last_run_at: None, next_fire_at: next_fire, run_count: 0, @@ -729,4 +733,14 @@ mod tests { // Must be valid UTF-8 (would have panicked otherwise). assert!(result.is_char_boundary(result.len())); } + + #[test] + fn cli_notify_config_defaults_to_runtime_target_resolution() { + let notify = cli_notify_config(Some("telegram".to_string())); + assert_eq!(notify.channel.as_deref(), Some("telegram")); // safety: test-only assertion + assert_eq!(notify.user, None); // safety: test-only assertion + assert!(notify.on_attention); // safety: test-only assertion + assert!(notify.on_failure); // safety: test-only assertion + assert!(!notify.on_success); // safety: test-only assertion + } } diff --git a/src/config/channels.rs b/src/config/channels.rs index 511f31c73b..6b1058a0e3 100644 --- a/src/config/channels.rs +++ b/src/config/channels.rs @@ -91,36 +91,24 @@ pub struct SignalConfig { } impl ChannelsConfig { - /// Resolve channels config following `env > settings > default` for every field. - pub(crate) fn resolve(settings: &Settings, tunnel_enabled: bool) -> Result { + pub(crate) fn resolve(settings: &Settings, owner_id: &str) -> Result { let cs = &settings.channels; - // --- HTTP webhook --- - // HTTP is enabled when env vars are set OR settings has it enabled. let http_enabled_by_env = optional_env("HTTP_PORT")?.is_some() || optional_env("HTTP_HOST")?.is_some(); - // When a tunnel is configured, default to loopback since external - // traffic arrives through the tunnel. Without a tunnel the webhook - // server needs to accept connections from the network directly. - let default_host = if tunnel_enabled { - "127.0.0.1" - } else { - "0.0.0.0" - }; let http = if http_enabled_by_env || cs.http_enabled { Some(HttpConfig { host: optional_env("HTTP_HOST")? .or_else(|| cs.http_host.clone()) - .unwrap_or_else(|| default_host.to_string()), + .unwrap_or_else(|| "0.0.0.0".to_string()), port: parse_optional_env("HTTP_PORT", cs.http_port.unwrap_or(8080))?, webhook_secret: optional_env("HTTP_WEBHOOK_SECRET")?.map(SecretString::from), - user_id: optional_env("HTTP_USER_ID")?.unwrap_or_else(|| "http".to_string()), + user_id: owner_id.to_string(), }) } else { None }; - // --- Web gateway --- let gateway_enabled = parse_bool_env("GATEWAY_ENABLED", cs.gateway_enabled)?; let gateway = if gateway_enabled { Some(GatewayConfig { @@ -133,33 +121,29 @@ impl ChannelsConfig { )?, auth_token: optional_env("GATEWAY_AUTH_TOKEN")? .or_else(|| cs.gateway_auth_token.clone()), - user_id: optional_env("GATEWAY_USER_ID")? - .or_else(|| cs.gateway_user_id.clone()) - .unwrap_or_else(|| "default".to_string()), + user_id: owner_id.to_string(), }) } else { None }; - // --- Signal --- let signal_url = optional_env("SIGNAL_HTTP_URL")?.or_else(|| cs.signal_http_url.clone()); let signal = if let Some(http_url) = signal_url { let account = optional_env("SIGNAL_ACCOUNT")? .or_else(|| cs.signal_account.clone()) .ok_or(ConfigError::InvalidValue { key: "SIGNAL_ACCOUNT".to_string(), - message: "SIGNAL_ACCOUNT is required when Signal is enabled".to_string(), + message: "SIGNAL_ACCOUNT is required when SIGNAL_HTTP_URL is set".to_string(), })?; - let allow_from_str = - optional_env("SIGNAL_ALLOW_FROM")?.or_else(|| cs.signal_allow_from.clone()); - let allow_from = match allow_from_str { - None => vec![account.clone()], - Some(s) => s - .split(',') - .map(|e| e.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(), - }; + let allow_from = + match optional_env("SIGNAL_ALLOW_FROM")?.or_else(|| cs.signal_allow_from.clone()) { + None => vec![account.clone()], + Some(s) => s + .split(',') + .map(|e| e.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(), + }; let dm_policy = optional_env("SIGNAL_DM_POLICY")? .or_else(|| cs.signal_dm_policy.clone()) .unwrap_or_else(|| "pairing".to_string()); @@ -201,18 +185,8 @@ impl ChannelsConfig { None }; - // --- CLI --- let cli_enabled = parse_bool_env("CLI_ENABLED", cs.cli_enabled)?; - // --- WASM channels --- - let wasm_channels_dir = optional_env("WASM_CHANNELS_DIR")? - .map(PathBuf::from) - .or_else(|| cs.wasm_channels_dir.clone()) - .unwrap_or_else(default_channels_dir); - - let wasm_channels_enabled = - parse_bool_env("WASM_CHANNELS_ENABLED", cs.wasm_channels_enabled)?; - Ok(Self { cli: CliConfig { enabled: cli_enabled, @@ -220,8 +194,14 @@ impl ChannelsConfig { http, gateway, signal, - wasm_channels_dir, - wasm_channels_enabled, + wasm_channels_dir: optional_env("WASM_CHANNELS_DIR")? + .map(PathBuf::from) + .or_else(|| cs.wasm_channels_dir.clone()) + .unwrap_or_else(default_channels_dir), + wasm_channels_enabled: parse_bool_env( + "WASM_CHANNELS_ENABLED", + cs.wasm_channels_enabled, + )?, wasm_channel_owner_ids: { let mut ids = cs.wasm_channel_owner_ids.clone(); // Backwards compat: TELEGRAM_OWNER_ID env var @@ -252,6 +232,8 @@ fn default_channels_dir() -> PathBuf { #[cfg(test)] mod tests { use crate::config::channels::*; + use crate::config::helpers::ENV_MUTEX; + use crate::settings::Settings; #[test] fn cli_config_fields() { @@ -398,69 +380,6 @@ mod tests { assert!(!cfg.wasm_channels_enabled); } - /// When a tunnel is active and HTTP_HOST is not explicitly set, the - /// webhook server should default to loopback to avoid unnecessary exposure. - #[test] - fn http_host_defaults_to_loopback_with_tunnel() { - // Set HTTP_PORT to trigger HttpConfig creation, but leave HTTP_HOST unset - // so the default kicks in. - unsafe { - std::env::set_var("HTTP_PORT", "9999"); - std::env::remove_var("HTTP_HOST"); - } - let settings = crate::settings::Settings::default(); - let cfg = ChannelsConfig::resolve(&settings, true).unwrap(); - unsafe { - std::env::remove_var("HTTP_PORT"); - } - let http = cfg.http.expect("HttpConfig should be present"); - assert_eq!( - http.host, "127.0.0.1", - "tunnel active should default to loopback" - ); - assert_eq!(http.port, 9999); - } - - /// Without a tunnel, the webhook server defaults to 0.0.0.0 so external - /// services can reach it directly. - #[test] - fn http_host_defaults_to_all_interfaces_without_tunnel() { - unsafe { - std::env::set_var("HTTP_PORT", "9998"); - std::env::remove_var("HTTP_HOST"); - } - let settings = crate::settings::Settings::default(); - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - unsafe { - std::env::remove_var("HTTP_PORT"); - } - let http = cfg.http.expect("HttpConfig should be present"); - assert_eq!( - http.host, "0.0.0.0", - "no tunnel should default to all interfaces" - ); - } - - /// An explicit HTTP_HOST always wins regardless of tunnel state. - #[test] - fn explicit_http_host_overrides_tunnel_default() { - unsafe { - std::env::set_var("HTTP_PORT", "9997"); - std::env::set_var("HTTP_HOST", "192.168.1.50"); - } - let settings = crate::settings::Settings::default(); - let cfg = ChannelsConfig::resolve(&settings, true).unwrap(); - unsafe { - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - } - let http = cfg.http.expect("HttpConfig should be present"); - assert_eq!( - http.host, "192.168.1.50", - "explicit host should override tunnel default" - ); - } - #[test] fn default_channels_dir_ends_with_channels() { let dir = default_channels_dir(); @@ -471,242 +390,43 @@ mod tests { } #[test] - fn default_gateway_port_constant() { - assert_eq!(DEFAULT_GATEWAY_PORT, 3000); - } - - /// With default settings and no env vars, gateway should use defaults. - #[test] - fn resolve_gateway_defaults_from_settings() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - // Clear env vars that would interfere - unsafe { - std::env::remove_var("GATEWAY_ENABLED"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - - let settings = crate::settings::Settings::default(); - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - - let gw = cfg.gateway.expect("gateway should be enabled by default"); - assert_eq!(gw.host, "127.0.0.1"); - assert_eq!(gw.port, DEFAULT_GATEWAY_PORT); - assert!(gw.auth_token.is_none()); - assert_eq!(gw.user_id, "default"); - } - - /// Settings values should be used when no env vars are set. - #[test] - fn resolve_gateway_from_settings() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - unsafe { - std::env::remove_var("GATEWAY_ENABLED"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - - let mut settings = crate::settings::Settings::default(); - settings.channels.gateway_port = Some(4000); - settings.channels.gateway_host = Some("0.0.0.0".to_string()); - settings.channels.gateway_auth_token = Some("db-token-123".to_string()); - settings.channels.gateway_user_id = Some("myuser".to_string()); - - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - let gw = cfg.gateway.expect("gateway should be enabled"); - assert_eq!(gw.port, 4000); - assert_eq!(gw.host, "0.0.0.0"); - assert_eq!(gw.auth_token.as_deref(), Some("db-token-123")); - assert_eq!(gw.user_id, "myuser"); - } - - /// Env vars should override settings values. - #[test] - fn resolve_env_overrides_settings() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - unsafe { - std::env::set_var("GATEWAY_PORT", "5000"); - std::env::set_var("GATEWAY_HOST", "10.0.0.1"); - std::env::set_var("GATEWAY_AUTH_TOKEN", "env-token"); - std::env::remove_var("GATEWAY_ENABLED"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - - let mut settings = crate::settings::Settings::default(); - settings.channels.gateway_port = Some(4000); - settings.channels.gateway_host = Some("0.0.0.0".to_string()); - settings.channels.gateway_auth_token = Some("db-token".to_string()); - - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - let gw = cfg.gateway.expect("gateway should be enabled"); - assert_eq!(gw.port, 5000, "env should override settings"); - assert_eq!(gw.host, "10.0.0.1", "env should override settings"); - assert_eq!( - gw.auth_token.as_deref(), - Some("env-token"), - "env should override settings" - ); - - // Cleanup - unsafe { - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - } - } - - /// CLI enabled should fall back to settings. - #[test] - fn resolve_cli_enabled_from_settings() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - unsafe { - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("GATEWAY_ENABLED"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - - let mut settings = crate::settings::Settings::default(); - settings.channels.cli_enabled = false; - - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - assert!(!cfg.cli.enabled, "settings should disable CLI"); - } - - /// HTTP channel should activate when settings has it enabled. - #[test] - fn resolve_http_from_settings() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - unsafe { - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("HTTP_WEBHOOK_SECRET"); - std::env::remove_var("HTTP_USER_ID"); - std::env::remove_var("GATEWAY_ENABLED"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - - let mut settings = crate::settings::Settings::default(); + fn resolve_uses_settings_channel_values_with_owner_scope_user_ids() { + let _guard = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()); + let mut settings = Settings::default(); settings.channels.http_enabled = true; - settings.channels.http_port = Some(9090); - settings.channels.http_host = Some("10.0.0.1".to_string()); - - let cfg = ChannelsConfig::resolve(&settings, false).unwrap(); - let http = cfg.http.expect("HTTP should be enabled from settings"); - assert_eq!(http.port, 9090); - assert_eq!(http.host, "10.0.0.1"); - } + settings.channels.http_host = Some("127.0.0.2".to_string()); + settings.channels.http_port = Some(8181); + settings.channels.gateway_enabled = true; + settings.channels.gateway_host = Some("127.0.0.3".to_string()); + settings.channels.gateway_port = Some(9191); + settings.channels.gateway_auth_token = Some("tok".to_string()); + settings.channels.signal_http_url = Some("http://127.0.0.1:8080".to_string()); + settings.channels.signal_account = Some("+15551234567".to_string()); + settings.channels.signal_allow_from = Some("+15551234567,+15557654321".to_string()); + settings.channels.wasm_channels_dir = Some(PathBuf::from("/tmp/settings-channels")); + settings.channels.wasm_channels_enabled = false; + + let cfg = ChannelsConfig::resolve(&settings, "owner-scope").expect("resolve"); + + let http = cfg.http.expect("http config"); + assert_eq!(http.host, "127.0.0.2"); + assert_eq!(http.port, 8181); + assert_eq!(http.user_id, "owner-scope"); + + let gateway = cfg.gateway.expect("gateway config"); + assert_eq!(gateway.host, "127.0.0.3"); + assert_eq!(gateway.port, 9191); + assert_eq!(gateway.auth_token.as_deref(), Some("tok")); + assert_eq!(gateway.user_id, "owner-scope"); + + let signal = cfg.signal.expect("signal config"); + assert_eq!(signal.account, "+15551234567"); + assert_eq!(signal.allow_from, vec!["+15551234567", "+15557654321"]); - /// Settings round-trip through DB map for new gateway fields. - #[test] - fn settings_gateway_fields_db_roundtrip() { - let mut settings = crate::settings::Settings::default(); - settings.channels.gateway_port = Some(4000); - settings.channels.gateway_host = Some("0.0.0.0".to_string()); - settings.channels.gateway_auth_token = Some("tok-abc".to_string()); - settings.channels.gateway_user_id = Some("myuser".to_string()); - settings.channels.cli_enabled = false; - - let map = settings.to_db_map(); - let restored = crate::settings::Settings::from_db_map(&map); - - assert_eq!(restored.channels.gateway_port, Some(4000)); - assert_eq!(restored.channels.gateway_host.as_deref(), Some("0.0.0.0")); assert_eq!( - restored.channels.gateway_auth_token.as_deref(), - Some("tok-abc") + cfg.wasm_channels_dir, + PathBuf::from("/tmp/settings-channels") ); - assert_eq!(restored.channels.gateway_user_id.as_deref(), Some("myuser")); - assert!(!restored.channels.cli_enabled); - } - - /// Invalid boolean env values must produce errors, not silently degrade. - #[test] - fn resolve_rejects_invalid_bool_env() { - let _lock = crate::config::helpers::ENV_MUTEX.lock(); - let settings = crate::settings::Settings::default(); - - // GATEWAY_ENABLED=maybe should error - unsafe { - std::env::set_var("GATEWAY_ENABLED", "maybe"); - std::env::remove_var("HTTP_PORT"); - std::env::remove_var("HTTP_HOST"); - std::env::remove_var("SIGNAL_HTTP_URL"); - std::env::remove_var("CLI_ENABLED"); - std::env::remove_var("WASM_CHANNELS_ENABLED"); - std::env::remove_var("GATEWAY_PORT"); - std::env::remove_var("GATEWAY_HOST"); - std::env::remove_var("GATEWAY_AUTH_TOKEN"); - std::env::remove_var("GATEWAY_USER_ID"); - std::env::remove_var("WASM_CHANNELS_DIR"); - std::env::remove_var("TELEGRAM_OWNER_ID"); - } - let result = ChannelsConfig::resolve(&settings, false); - assert!(result.is_err(), "GATEWAY_ENABLED=maybe should be rejected"); - - // CLI_ENABLED=on should error - unsafe { - std::env::remove_var("GATEWAY_ENABLED"); - std::env::set_var("CLI_ENABLED", "on"); - } - let result = ChannelsConfig::resolve(&settings, false); - assert!(result.is_err(), "CLI_ENABLED=on should be rejected"); - - // WASM_CHANNELS_ENABLED=yes should error - unsafe { - std::env::remove_var("CLI_ENABLED"); - std::env::set_var("WASM_CHANNELS_ENABLED", "yes"); - } - let result = ChannelsConfig::resolve(&settings, false); - assert!( - result.is_err(), - "WASM_CHANNELS_ENABLED=yes should be rejected" - ); - - // Cleanup - unsafe { - std::env::remove_var("WASM_CHANNELS_ENABLED"); - } + assert!(!cfg.wasm_channels_enabled); } } diff --git a/src/config/llm.rs b/src/config/llm.rs index 4ad2439928..64bf4ab8cc 100644 --- a/src/config/llm.rs +++ b/src/config/llm.rs @@ -38,6 +38,8 @@ impl LlmConfig { provider: None, bedrock: None, request_timeout_secs: 120, + cheap_model: None, + smart_routing_cascade: false, } } @@ -168,6 +170,14 @@ impl LlmConfig { let request_timeout_secs = parse_optional_env("LLM_REQUEST_TIMEOUT_SECS", 120)?; + // Generic cheap model (works with any backend). + // Falls back to NearAI-specific cheap_model in provider chain logic. + let cheap_model = optional_env("LLM_CHEAP_MODEL")?; + + // Generic smart routing cascade flag. + // Defaults to true. Overrides NearAI-specific smart_routing_cascade. + let smart_routing_cascade = parse_optional_env("SMART_ROUTING_CASCADE", true)?; + Ok(Self { backend: if is_nearai { "nearai".to_string() @@ -183,6 +193,8 @@ impl LlmConfig { provider, bedrock, request_timeout_secs, + cheap_model, + smart_routing_cascade, }) } diff --git a/src/config/mod.rs b/src/config/mod.rs index 1c81329e11..38c8088050 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -26,7 +26,7 @@ mod tunnel; mod wasm; use std::collections::HashMap; -use std::sync::{LazyLock, Mutex}; +use std::sync::{LazyLock, Mutex, Once}; use crate::error::ConfigError; use crate::settings::Settings; @@ -74,10 +74,12 @@ pub use self::helpers::{env_or_override, set_runtime_env}; /// their data. Whichever runs first initialises the map; the second merges in. static INJECTED_VARS: LazyLock>> = LazyLock::new(|| Mutex::new(HashMap::new())); +static WARNED_EXPLICIT_DEFAULT_OWNER_ID: Once = Once::new(); /// Main configuration for the agent. #[derive(Debug, Clone)] pub struct Config { + pub owner_id: String, pub database: DatabaseConfig, pub llm: LlmConfig, pub embeddings: EmbeddingsConfig, @@ -118,6 +120,7 @@ impl Config { installed_skills_dir: std::path::PathBuf, ) -> Self { Self { + owner_id: "default".to_string(), database: DatabaseConfig { backend: DatabaseBackend::LibSql, url: secrecy::SecretString::from("unused://test".to_string()), @@ -228,13 +231,7 @@ impl Config { pub async fn from_env_with_toml( toml_path: Option<&std::path::Path>, ) -> Result { - let _ = dotenvy::dotenv(); - crate::bootstrap::load_ironclaw_env(); - let mut settings = Settings::load(); - - // Overlay TOML config file (values win over JSON settings) - Self::apply_toml_overlay(&mut settings, toml_path)?; - + let settings = load_bootstrap_settings(toml_path)?; Self::build(&settings).await } @@ -306,16 +303,15 @@ impl Config { /// Build config from settings (shared by from_env and from_db). async fn build(settings: &Settings) -> Result { - // Resolve tunnel first so channels can default to loopback when a - // tunnel handles external exposure (no need to bind 0.0.0.0). - let tunnel = TunnelConfig::resolve(settings)?; + let owner_id = resolve_owner_id(settings)?; Ok(Self { + owner_id: owner_id.clone(), database: DatabaseConfig::resolve()?, llm: LlmConfig::resolve(settings)?, embeddings: EmbeddingsConfig::resolve(settings)?, - channels: ChannelsConfig::resolve(settings, tunnel.is_enabled())?, - tunnel, + tunnel: TunnelConfig::resolve(settings)?, + channels: ChannelsConfig::resolve(settings, &owner_id)?, agent: AgentConfig::resolve(settings)?, safety: resolve_safety_config(settings)?, wasm: WasmConfig::resolve(settings)?, @@ -337,6 +333,43 @@ impl Config { } } +pub(crate) fn load_bootstrap_settings( + toml_path: Option<&std::path::Path>, +) -> Result { + let _ = dotenvy::dotenv(); + crate::bootstrap::load_ironclaw_env(); + + let mut settings = Settings::load(); + Config::apply_toml_overlay(&mut settings, toml_path)?; + Ok(settings) +} + +pub(crate) fn resolve_owner_id(settings: &Settings) -> Result { + let env_owner_id = self::helpers::optional_env("IRONCLAW_OWNER_ID")?; + let settings_owner_id = settings.owner_id.clone(); + let configured_owner_id = env_owner_id.clone().or(settings_owner_id.clone()); + + let owner_id = configured_owner_id + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "default".to_string()); + + if owner_id == "default" + && (env_owner_id.is_some() + || settings_owner_id + .as_deref() + .is_some_and(|value| !value.trim().is_empty())) + { + WARNED_EXPLICIT_DEFAULT_OWNER_ID.call_once(|| { + tracing::warn!( + "IRONCLAW_OWNER_ID resolved to the legacy 'default' scope explicitly; durable state will keep legacy owner behavior" + ); + }); + } + + Ok(owner_id) +} + /// Load API keys from the encrypted secrets store into a thread-safe overlay. /// /// This bridges the gap between secrets stored during onboarding and the diff --git a/src/config/transcription.rs b/src/config/transcription.rs index b0f7606604..da2bac25a0 100644 --- a/src/config/transcription.rs +++ b/src/config/transcription.rs @@ -9,11 +9,15 @@ use crate::settings::Settings; pub struct TranscriptionConfig { /// Whether audio transcription is enabled. pub enabled: bool, - /// Provider: "openai" (default). + /// Provider: "openai" (default) or "chat_completions". pub provider: String, /// OpenAI API key (reuses OPENAI_API_KEY). pub openai_api_key: Option, - /// Model to use (default: "whisper-1"). + /// Explicit transcription API key (overrides provider-specific keys). + pub api_key: Option, + /// LLM API key (reuses LLM_API_KEY, used as fallback for chat_completions). + pub llm_api_key: Option, + /// Model to use (default depends on provider). pub model: String, /// Base URL override for the transcription API. pub base_url: Option, @@ -25,6 +29,8 @@ impl Default for TranscriptionConfig { enabled: false, provider: "openai".to_string(), openai_api_key: None, + api_key: None, + llm_api_key: None, model: "whisper-1".to_string(), base_url: None, } @@ -42,8 +48,15 @@ impl TranscriptionConfig { optional_env("TRANSCRIPTION_PROVIDER")?.unwrap_or_else(|| "openai".to_string()); let openai_api_key = optional_env("OPENAI_API_KEY")?.map(SecretString::from); + let api_key = optional_env("TRANSCRIPTION_API_KEY")?.map(SecretString::from); + let llm_api_key = optional_env("LLM_API_KEY")?.map(SecretString::from); - let model = optional_env("TRANSCRIPTION_MODEL")?.unwrap_or_else(|| "whisper-1".to_string()); + let default_model = match provider.as_str() { + "chat_completions" => "google/gemini-2.0-flash-001", + _ => "whisper-1", + }; + let model = + optional_env("TRANSCRIPTION_MODEL")?.unwrap_or_else(|| default_model.to_string()); let base_url = optional_env("TRANSCRIPTION_BASE_URL")?; @@ -51,29 +64,67 @@ impl TranscriptionConfig { enabled, provider, openai_api_key, + api_key, + llm_api_key, model, base_url, }) } + /// Resolve the API key for the configured provider. + /// + /// Priority: `TRANSCRIPTION_API_KEY` > provider-specific key. + fn resolve_api_key(&self) -> Option<&SecretString> { + self.api_key + .as_ref() + .or_else(|| match self.provider.as_str() { + "chat_completions" => self.llm_api_key.as_ref().or(self.openai_api_key.as_ref()), + _ => self.openai_api_key.as_ref(), + }) + } + /// Create the transcription provider if enabled and configured. pub fn create_provider(&self) -> Option> { if !self.enabled { return None; } - // Currently only OpenAI Whisper is supported; more providers can be - // added here with a match on self.provider. - let api_key = self.openai_api_key.as_ref()?; - tracing::info!(model = %self.model, "Audio transcription enabled via OpenAI Whisper"); + let api_key = self.resolve_api_key()?; - let mut provider = crate::transcription::OpenAiWhisperProvider::new(api_key.clone()) - .with_model(&self.model); + match self.provider.as_str() { + "chat_completions" => { + tracing::info!( + model = %self.model, + "Audio transcription enabled via Chat Completions API" + ); - if let Some(ref base_url) = self.base_url { - provider = provider.with_base_url(base_url); - } + let mut provider = crate::transcription::ChatCompletionsTranscriptionProvider::new( + api_key.clone(), + ) + .with_model(&self.model); + + if let Some(ref base_url) = self.base_url { + provider = provider.with_base_url(base_url); + } - Some(Box::new(provider)) + Some(Box::new(provider)) + } + _ => { + tracing::info!( + model = %self.model, + "Audio transcription enabled via OpenAI Whisper" + ); + + let mut provider = + crate::transcription::OpenAiWhisperProvider::new(api_key.clone()) + .with_model(&self.model); + + if let Some(ref base_url) = self.base_url { + provider = provider.with_base_url(base_url); + } + + Some(Box::new(provider)) + } + } } } diff --git a/src/context/state.rs b/src/context/state.rs index 768e4da6b0..2402fd66b6 100644 --- a/src/context/state.rs +++ b/src/context/state.rs @@ -121,6 +121,9 @@ pub struct JobContext { pub state: JobState, /// User ID that owns this job (for workspace scoping). pub user_id: String, + /// Channel-specific requester/actor ID, when different from the owner scope. + #[serde(skip_serializing_if = "Option::is_none")] + pub requester_id: Option, /// Conversation ID if linked to a conversation. pub conversation_id: Option, /// Job title. @@ -202,6 +205,7 @@ impl JobContext { job_id: Uuid::new_v4(), state: JobState::Pending, user_id: user_id.into(), + requester_id: None, conversation_id: None, title: title.into(), description: description.into(), @@ -233,6 +237,12 @@ impl JobContext { self } + /// Set the channel-specific requester/actor ID. + pub fn with_requester_id(mut self, requester_id: impl Into) -> Self { + self.requester_id = Some(requester_id.into()); + self + } + /// Transition to a new state. pub fn transition_to( &mut self, diff --git a/src/db/libsql/jobs.rs b/src/db/libsql/jobs.rs index 3db3ab3078..208d348b9d 100644 --- a/src/db/libsql/jobs.rs +++ b/src/db/libsql/jobs.rs @@ -106,6 +106,7 @@ impl JobStore for LibSqlBackend { job_id: get_text(&row, 0).parse().unwrap_or_default(), state, user_id: get_text(&row, 6), + requester_id: None, conversation_id: get_opt_text(&row, 1).and_then(|s| s.parse().ok()), title: get_text(&row, 2), description: get_text(&row, 3), diff --git a/src/db/libsql/mod.rs b/src/db/libsql/mod.rs index dcc5a8b5c4..d19089c102 100644 --- a/src/db/libsql/mod.rs +++ b/src/db/libsql/mod.rs @@ -247,6 +247,17 @@ pub(crate) fn opt_text_owned(s: Option) -> libsql::Value { } } +pub(crate) fn normalize_notify_user(value: Option) -> Option { + value.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() || trimmed == "default" { + None + } else { + Some(trimmed.to_string()) + } + }) +} + /// Extract an i64 column, defaulting to 0. pub(crate) fn get_i64(row: &libsql::Row, idx: i32) -> i64 { row.get::(idx).unwrap_or(0) @@ -378,7 +389,7 @@ pub(crate) fn row_to_routine_libsql(row: &libsql::Row) -> Result MakeRustlsConnect { +fn make_rustls_connector() -> Result { let mut root_store = rustls::RootCertStore::empty(); let native = rustls_native_certs::load_native_certs(); for e in &native.errors { @@ -25,10 +34,15 @@ fn make_rustls_connector() -> MakeRustlsConnect { if root_store.is_empty() { tracing::error!("no system root certificates found -- TLS connections will fail"); } - let config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - MakeRustlsConnect::new(config) + // `--all-features` brings in both aws-lc-rs and ring-backed rustls providers. + // Pick the same ring provider reqwest already uses so postgres TLS setup stays deterministic. + let config = rustls::ClientConfig::builder_with_provider( + rustls::crypto::ring::default_provider().into(), + ) + .with_safe_default_protocol_versions()? + .with_root_certificates(root_store) + .with_no_client_auth(); + Ok(MakeRustlsConnect::new(config)) } /// Create a [`deadpool_postgres::Pool`] with the appropriate TLS connector. @@ -45,12 +59,16 @@ fn make_rustls_connector() -> MakeRustlsConnect { pub fn create_pool( config: &deadpool_postgres::Config, ssl_mode: SslMode, -) -> Result { +) -> Result { match ssl_mode { - SslMode::Disable => config.create_pool(Some(Runtime::Tokio1), NoTls), + SslMode::Disable => config + .create_pool(Some(Runtime::Tokio1), NoTls) + .map_err(CreatePoolError::from), SslMode::Prefer | SslMode::Require => { - let tls = make_rustls_connector(); - config.create_pool(Some(Runtime::Tokio1), tls) + let tls = make_rustls_connector()?; + config + .create_pool(Some(Runtime::Tokio1), tls) + .map_err(CreatePoolError::from) } } } diff --git a/src/error.rs b/src/error.rs index 9e57a358c8..11864de783 100644 --- a/src/error.rs +++ b/src/error.rs @@ -122,6 +122,9 @@ pub enum ChannelError { #[error("Failed to send response on channel {name}: {reason}")] SendFailed { name: String, reason: String }, + #[error("Channel {name} is missing a routing target: {reason}")] + MissingRoutingTarget { name: String, reason: String }, + #[error("Invalid message format: {0}")] InvalidMessage(String), diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index d63ae446a0..471f10cf89 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -3419,6 +3419,7 @@ impl ExtensionManager { Arc::clone(&channel_runtime), Arc::clone(&pairing_store), settings_store, + self.user_id.clone(), ) .with_secrets_store(Arc::clone(&self.secrets)); loader @@ -3435,6 +3436,7 @@ impl ExtensionManager { Arc::clone(&channel_runtime), Arc::clone(&pairing_store), settings_store, + self.user_id.clone(), ) .with_secrets_store(Arc::clone(&self.secrets)); loader @@ -3462,6 +3464,7 @@ impl ExtensionManager { owner_id: Option, ) -> Result { let channel_name = loaded.name().to_string(); + let owner_actor_id = owner_id.map(|id| id.to_string()); let webhook_secret_name = loaded.webhook_secret_name(); let secret_header = loaded.webhook_secret_header().map(|s| s.to_string()); let sig_key_secret_name = loaded.signature_key_secret_name(); @@ -3475,7 +3478,7 @@ impl ExtensionManager { .ok() .map(|s| s.expose().to_string()); - let channel_arc = Arc::new(loaded.channel); + let channel_arc = Arc::new(loaded.channel.with_owner_actor_id(owner_actor_id)); // Inject runtime config (tunnel_url, webhook_secret, owner_id) { @@ -5615,6 +5618,7 @@ mod tests { runtime, prepared, capabilities, + "default", "{}".to_string(), pairing_store, None, diff --git a/src/history/store.rs b/src/history/store.rs index 17fa96fd45..04e3167f28 100644 --- a/src/history/store.rs +++ b/src/history/store.rs @@ -227,6 +227,7 @@ impl Store { job_id: row.get("id"), state, user_id: row.get::<_, String>("user_id"), + requester_id: None, conversation_id: row.get("conversation_id"), title: row.get("title"), description: row.get("description"), diff --git a/src/llm/config.rs b/src/llm/config.rs index 8b7d41c3c8..413f80e209 100644 --- a/src/llm/config.rs +++ b/src/llm/config.rs @@ -138,6 +138,30 @@ pub struct LlmConfig { /// Default: 120. Increase for local LLMs (Ollama, vLLM, LM Studio) that /// need more time for prompt evaluation on consumer hardware. pub request_timeout_secs: u64, + /// Generic cheap/fast model for lightweight tasks (heartbeat, routing, evaluation). + /// Works with any backend. Set via `LLM_CHEAP_MODEL` env var. + /// When set, takes priority over the NearAI-specific `NEARAI_CHEAP_MODEL`. + pub cheap_model: Option, + /// Enable cascade mode for smart routing (retry with primary if cheap model + /// response seems uncertain). Default: true. Set via `SMART_ROUTING_CASCADE`. + pub smart_routing_cascade: bool, +} + +impl LlmConfig { + /// Resolve the effective cheap model name. + /// + /// Resolution order: + /// 1. `LLM_CHEAP_MODEL` (generic, works with any backend) + /// 2. `NEARAI_CHEAP_MODEL` (NearAI-only, backward compatibility) + pub fn cheap_model_name(&self) -> Option<&str> { + self.cheap_model.as_deref().or_else(|| { + if self.backend == "nearai" { + self.nearai.cheap_model.as_deref() + } else { + None + } + }) + } } /// NEAR AI configuration. diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 51309bf37d..3b6b01c472 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -376,32 +376,61 @@ fn create_ollama_from_registry( /// Create a cheap/fast LLM provider for lightweight tasks (heartbeat, routing, evaluation). /// -/// Uses `NEARAI_CHEAP_MODEL` if set, otherwise falls back to the main provider. -/// Currently only supports NEAR AI backend. +/// Resolution order: +/// 1. `LLM_CHEAP_MODEL` (generic, works with any backend) +/// 2. `NEARAI_CHEAP_MODEL` (NearAI-only, backward compatibility) +/// +/// Returns `None` if no cheap model is configured. pub fn create_cheap_llm_provider( config: &LlmConfig, session: Arc, ) -> Result>, LlmError> { - let Some(ref cheap_model) = config.nearai.cheap_model else { + let Some(cheap_model) = config.cheap_model_name() else { return Ok(None); }; - if config.backend != "nearai" { - tracing::warn!( - "NEARAI_CHEAP_MODEL is set but LLM_BACKEND is '{}', not nearai. \ - Cheap model setting will be ignored.", - config.backend - ); - return Ok(None); + create_cheap_provider_for_backend(config, session, cheap_model) +} + +/// Create a cheap provider for a specific backend. +/// +/// Handles backend-specific provider construction: +/// - `nearai` — clones NearAiConfig, swaps model, uses `create_llm_provider_with_config` +/// - `bedrock` — returns error (smart routing not yet supported) +/// - All others — clones `RegistryProviderConfig`, swaps model, uses `create_registry_provider` +fn create_cheap_provider_for_backend( + config: &LlmConfig, + session: Arc, + cheap_model: &str, +) -> Result>, LlmError> { + if config.backend == "nearai" { + let mut cheap_config = config.nearai.clone(); + cheap_config.model = cheap_model.to_string(); + let provider = + create_llm_provider_with_config(&cheap_config, session, config.request_timeout_secs)?; + return Ok(Some(provider)); + } + + if config.backend == "bedrock" { + return Err(LlmError::RequestFailed { + provider: "bedrock".to_string(), + reason: "Smart routing with cheap model is not supported for Bedrock yet".to_string(), + }); } - let mut cheap_config = config.nearai.clone(); - cheap_config.model = cheap_model.clone(); + // Registry-based provider: clone config and swap model + let reg_config = config.provider.as_ref().ok_or_else(|| LlmError::RequestFailed { + provider: config.backend.clone(), + reason: format!( + "Cannot create cheap provider for backend '{}': no registry provider config available", + config.backend + ), + })?; - Ok(Some(Arc::new(NearAiChatProvider::new( - cheap_config, - session, - )?))) + let mut cheap_reg_config = reg_config.clone(); + cheap_reg_config.model = cheap_model.to_string(); + let provider = create_registry_provider(&cheap_reg_config, config.request_timeout_secs)?; + Ok(Some(provider)) } /// Build the full LLM provider chain with all configured wrappers. @@ -449,14 +478,15 @@ pub async fn build_provider_chain( }; // 2. Smart routing (cheap/primary split) - let llm: Arc = if let Some(ref cheap_model) = config.nearai.cheap_model { - let mut cheap_config = config.nearai.clone(); - cheap_config.model = cheap_model.clone(); - let cheap = create_llm_provider_with_config( - &cheap_config, - session.clone(), - config.request_timeout_secs, - )?; + let llm: Arc = if let Some(cheap_model) = config.cheap_model_name() { + let cheap = create_cheap_provider_for_backend(config, session.clone(), cheap_model)? + .ok_or_else(|| LlmError::RequestFailed { + provider: config.backend.clone(), + reason: format!( + "Failed to create cheap provider for model '{cheap_model}' on backend '{}'", + config.backend + ), + })?; let cheap: Arc = if retry_config.max_retries > 0 { Arc::new(RetryProvider::new(cheap, retry_config.clone())) } else { @@ -471,7 +501,7 @@ pub async fn build_provider_chain( llm, cheap, SmartRoutingConfig { - cascade_enabled: config.nearai.smart_routing_cascade, + cascade_enabled: config.smart_routing_cascade, ..SmartRoutingConfig::default() }, )) @@ -600,6 +630,8 @@ mod tests { provider: None, bedrock: None, request_timeout_secs: 120, + cheap_model: None, + smart_routing_cascade: true, } } @@ -614,7 +646,7 @@ mod tests { } #[test] - fn test_create_cheap_llm_provider_creates_provider_when_configured() { + fn test_create_cheap_llm_provider_creates_provider_with_nearai_cheap_model() { let mut config = test_llm_config(); config.nearai.cheap_model = Some("cheap-test-model".to_string()); @@ -628,7 +660,26 @@ mod tests { } #[test] - fn test_create_cheap_llm_provider_ignored_for_non_nearai_backend() { + fn test_create_cheap_llm_provider_generic_overrides_nearai() { + let mut config = test_llm_config(); + config.nearai.cheap_model = Some("nearai-cheap".to_string()); + config.cheap_model = Some("generic-cheap".to_string()); + + let session = Arc::new(SessionManager::new(SessionConfig::default())); + let result = create_cheap_llm_provider(&config, session); + + assert!(result.is_ok()); + let provider = result.unwrap(); + assert!(provider.is_some()); + assert_eq!( + provider.unwrap().model_name(), + "generic-cheap", + "LLM_CHEAP_MODEL should take priority over NEARAI_CHEAP_MODEL" + ); + } + + #[test] + fn test_create_cheap_llm_provider_nearai_cheap_ignored_for_non_nearai_backend() { let mut config = test_llm_config(); config.backend = "openai".to_string(); config.nearai.cheap_model = Some("cheap-test-model".to_string()); @@ -637,6 +688,48 @@ mod tests { let result = create_cheap_llm_provider(&config, session); assert!(result.is_ok()); - assert!(result.unwrap().is_none()); + assert!( + result.unwrap().is_none(), + "NEARAI_CHEAP_MODEL should be ignored when backend is not nearai" + ); + } + + #[test] + fn test_create_cheap_llm_provider_bedrock_returns_error() { + let mut config = test_llm_config(); + config.backend = "bedrock".to_string(); + config.cheap_model = Some("cheap-model".to_string()); + + let session = Arc::new(SessionManager::new(SessionConfig::default())); + let result = create_cheap_llm_provider(&config, session); + + assert!( + result.is_err(), + "Bedrock should return an error for cheap model" + ); + } + + #[test] + fn test_cheap_model_name_resolution() { + // Generic takes priority + let mut config = test_llm_config(); + config.cheap_model = Some("generic".to_string()); + config.nearai.cheap_model = Some("nearai".to_string()); + assert_eq!(config.cheap_model_name(), Some("generic")); + + // NearAI fallback when backend is nearai + let mut config = test_llm_config(); + config.nearai.cheap_model = Some("nearai".to_string()); + assert_eq!(config.cheap_model_name(), Some("nearai")); + + // NearAI ignored for non-nearai backend + let mut config = test_llm_config(); + config.backend = "openai".to_string(); + config.nearai.cheap_model = Some("nearai".to_string()); + assert_eq!(config.cheap_model_name(), None); + + // None when nothing configured + let config = test_llm_config(); + assert_eq!(config.cheap_model_name(), None); } } diff --git a/src/llm/models.rs b/src/llm/models.rs index 7022d3cf6a..daec9df398 100644 --- a/src/llm/models.rs +++ b/src/llm/models.rs @@ -345,5 +345,7 @@ pub(crate) fn build_nearai_model_fetch_config() -> crate::config::LlmConfig { provider: None, bedrock: None, request_timeout_secs: 120, + cheap_model: None, + smart_routing_cascade: false, } } diff --git a/src/main.rs b/src/main.rs index 574616772d..ae864bed9b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -153,7 +153,8 @@ async fn async_main() -> anyhow::Result<()> { provider_only: *provider_only, quick: *quick, }; - let mut wizard = SetupWizard::with_config(config); + let mut wizard = + SetupWizard::try_with_config_and_toml(config, cli.config.as_deref())?; wizard.run().await?; } #[cfg(not(any(feature = "postgres", feature = "libsql")))] @@ -195,10 +196,13 @@ async fn async_main() -> anyhow::Result<()> { { println!("Onboarding needed: {}", reason); println!(); - let mut wizard = SetupWizard::with_config(SetupConfig { - quick: true, - ..Default::default() - }); + let mut wizard = SetupWizard::try_with_config_and_toml( + SetupConfig { + quick: true, + ..Default::default() + }, + cli.config.as_deref(), + )?; wizard.run().await?; } @@ -282,9 +286,12 @@ async fn async_main() -> anyhow::Result<()> { // Create CLI channel let repl_channel = if let Some(ref msg) = cli.message { - Some(ReplChannel::with_message(msg.clone())) + Some(ReplChannel::with_message_for_user( + config.owner_id.clone(), + msg.clone(), + )) } else if config.channels.cli.enabled { - let repl = ReplChannel::new(); + let repl = ReplChannel::with_user_id(config.owner_id.clone()); repl.suppress_banner(); Some(repl) } else { @@ -311,12 +318,7 @@ async fn async_main() -> anyhow::Result<()> { webhook_routes.push(webhooks::routes(ToolWebhookState { tools: Arc::clone(&components.tools), routine_engine: Arc::clone(&shared_routine_engine_slot), - user_id: config - .channels - .gateway - .as_ref() - .map(|g| g.user_id.clone()) - .unwrap_or_else(|| "default".to_string()), + user_id: config.owner_id.clone(), secrets_store: components.secrets_store.clone(), })); @@ -703,6 +705,7 @@ async fn async_main() -> anyhow::Result<()> { .map(|db| Arc::clone(db) as Arc); let deps = AgentDeps { + owner_id: config.owner_id.clone(), store: components.db, llm: components.llm, cheap_llm: components.cheap_llm, @@ -775,6 +778,7 @@ async fn async_main() -> anyhow::Result<()> { let sighup_webhook_server = webhook_server.clone(); let sighup_settings_store_clone = sighup_settings_store.clone(); let sighup_secrets_store = components.secrets_store.clone(); + let sighup_owner_id = config.owner_id.clone(); let mut shutdown_rx = shutdown_tx.subscribe(); tokio::spawn(async move { @@ -805,7 +809,7 @@ async fn async_main() -> anyhow::Result<()> { if let Some(ref secrets_store) = sighup_secrets_store { // Inject HTTP webhook secret from encrypted store if let Ok(webhook_secret) = secrets_store - .get_decrypted("default", "http_webhook_secret") + .get_decrypted(&sighup_owner_id, "http_webhook_secret") .await { // Thread-safe: Uses INJECTED_VARS mutex instead of unsafe std::env::set_var @@ -821,7 +825,7 @@ async fn async_main() -> anyhow::Result<()> { // Reload config (now with secrets injected into environment) let new_config = match &sighup_settings_store_clone { Some(store) => { - ironclaw::config::Config::from_db(store.as_ref(), "default").await + ironclaw::config::Config::from_db(store.as_ref(), &sighup_owner_id).await } None => ironclaw::config::Config::from_env().await, }; diff --git a/src/orchestrator/mod.rs b/src/orchestrator/mod.rs index 5e750ddfcc..b72f90ee49 100644 --- a/src/orchestrator/mod.rs +++ b/src/orchestrator/mod.rs @@ -51,6 +51,15 @@ use crate::db::Database; use crate::llm::LlmProvider; use crate::secrets::SecretsStore; +/// Resolve the orchestrator port from the `ORCHESTRATOR_PORT` environment +/// variable, falling back to 50051. +fn resolve_orchestrator_port() -> u16 { + std::env::var("ORCHESTRATOR_PORT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(50051) +} + /// Result of orchestrator setup, containing all handles needed by the agent. pub struct OrchestratorSetup { pub container_job_manager: Option>, @@ -101,11 +110,12 @@ pub async fn setup_orchestrator( let job_event_tx = Some(tx); let token_store = TokenStore::new(); + let orchestrator_port = resolve_orchestrator_port(); let job_config = ContainerJobConfig { image: config.sandbox.image.clone(), memory_limit_mb: config.sandbox.memory_limit_mb, cpu_shares: config.sandbox.cpu_shares, - orchestrator_port: 50051, + orchestrator_port, claude_code_api_key: std::env::var("ANTHROPIC_API_KEY").ok(), claude_code_oauth_token: crate::config::ClaudeCodeConfig::extract_oauth_token(), claude_code_model: config.claude_code.model.clone(), @@ -127,7 +137,7 @@ pub async fn setup_orchestrator( }; tokio::spawn(async move { - if let Err(e) = OrchestratorApi::start(orchestrator_state, 50051).await { + if let Err(e) = OrchestratorApi::start(orchestrator_state, orchestrator_port).await { tracing::error!("Orchestrator API failed: {}", e); } }); @@ -151,3 +161,40 @@ pub async fn setup_orchestrator( docker_status, } } + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + + /// Serialize access to `ORCHESTRATOR_PORT` env var across test threads. + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + #[test] + fn resolve_orchestrator_port_from_env() { + let _guard = ENV_LOCK.lock().unwrap(); + + // Safety: env-var mutation requires unsafe in edition 2024; + // ENV_LOCK serializes concurrent access from other test threads. + + // Absent env var → default 50051 + unsafe { std::env::remove_var("ORCHESTRATOR_PORT") }; + assert_eq!(resolve_orchestrator_port(), 50051); + + // Valid custom port + unsafe { std::env::set_var("ORCHESTRATOR_PORT", "50052") }; + assert_eq!(resolve_orchestrator_port(), 50052); + + // Non-numeric value → fallback to default + unsafe { std::env::set_var("ORCHESTRATOR_PORT", "not_a_port") }; + assert_eq!(resolve_orchestrator_port(), 50051); + + // Out of u16 range → fallback to default + unsafe { std::env::set_var("ORCHESTRATOR_PORT", "99999") }; + assert_eq!(resolve_orchestrator_port(), 50051); + + // Cleanup + unsafe { std::env::remove_var("ORCHESTRATOR_PORT") }; + } +} diff --git a/src/settings.rs b/src/settings.rs index 2a5b6bbd21..9a0b3942a0 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -16,6 +16,14 @@ pub struct Settings { #[serde(default, alias = "setup_completed")] pub onboard_completed: bool, + /// Stable owner scope for this IronClaw instance. + /// + /// This is bootstrap configuration loaded from env / disk / TOML. We do + /// not persist it in the per-user DB settings table because the DB lookup + /// itself already requires the owner scope to be known. + #[serde(default)] + pub owner_id: Option, + // === Step 1: Database === /// Database backend: "postgres" or "libsql". #[serde(default)] @@ -733,6 +741,10 @@ impl Settings { let mut settings = Self::default(); for (key, value) in map { + if key == "owner_id" { + continue; + } + // Convert the JSONB value to a string for the existing set() method let value_str = match value { serde_json::Value::String(s) => s.clone(), @@ -772,6 +784,7 @@ impl Settings { let mut map = std::collections::HashMap::new(); collect_settings_json(&json, String::new(), &mut map); + map.remove("owner_id"); map } diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs index 9437d8279b..23494d12e9 100644 --- a/src/setup/wizard.rs +++ b/src/setup/wizard.rs @@ -14,6 +14,8 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +#[cfg(feature = "postgres")] +use deadpool_postgres::Config as PoolConfig; use secrecy::{ExposeSecret, SecretString}; use crate::bootstrap::ironclaw_base_dir; @@ -25,8 +27,10 @@ use crate::llm::models::{ build_nearai_model_fetch_config, fetch_anthropic_models, fetch_ollama_models, fetch_openai_compatible_models, fetch_openai_models, }; +#[cfg(test)] +use crate::llm::models::{is_openai_chat_model, sort_openai_models}; use crate::llm::{SessionConfig, SessionManager}; -use crate::secrets::SecretsCrypto; +use crate::secrets::{SecretsCrypto, SecretsStore}; use crate::settings::{KeySource, Settings}; use crate::setup::channels::{ SecretsContext, setup_http, setup_signal, setup_tunnel, setup_wasm_channel, @@ -86,11 +90,14 @@ pub struct SetupConfig { pub struct SetupWizard { config: SetupConfig, settings: Settings, + owner_id: String, session_manager: Option>, - /// Backend-agnostic database trait object (created during setup). - db: Option>, - /// Backend-specific handles for secrets store and other satellite consumers. - db_handles: Option, + /// Database pool (created during setup, postgres only). + #[cfg(feature = "postgres")] + db_pool: Option, + /// libSQL backend (created during setup, libsql only). + #[cfg(feature = "libsql")] + db_backend: Option, /// Secrets crypto (created during setup). secrets_crypto: Option>, /// Cached API key from provider setup (used by model fetcher without env mutation). @@ -98,30 +105,71 @@ pub struct SetupWizard { } impl SetupWizard { - /// Create a new setup wizard. - pub fn new() -> Self { + fn owner_id(&self) -> &str { + &self.owner_id + } + + fn fallback_with_default_owner( + config: SetupConfig, + settings: Settings, + error: &crate::error::ConfigError, + ) -> Self { + tracing::warn!("Falling back to default owner scope for setup wizard: {error}"); Self { - config: SetupConfig::default(), - settings: Settings::default(), + config, + settings, + owner_id: "default".to_string(), session_manager: None, - db: None, - db_handles: None, + #[cfg(feature = "postgres")] + db_pool: None, + #[cfg(feature = "libsql")] + db_backend: None, secrets_crypto: None, llm_api_key: None, } } - /// Create a wizard with custom configuration. - pub fn with_config(config: SetupConfig) -> Self { - Self { + fn from_bootstrap_settings( + config: SetupConfig, + settings: Settings, + ) -> Result { + let owner_id = crate::config::resolve_owner_id(&settings)?; + Ok(Self { config, - settings: Settings::default(), + settings, + owner_id, session_manager: None, - db: None, - db_handles: None, + #[cfg(feature = "postgres")] + db_pool: None, + #[cfg(feature = "libsql")] + db_backend: None, secrets_crypto: None, llm_api_key: None, - } + }) + } + + /// Create a new setup wizard. + pub fn new() -> Self { + let settings = crate::config::load_bootstrap_settings(None).unwrap_or_default(); + Self::from_bootstrap_settings(SetupConfig::default(), settings.clone()).unwrap_or_else( + |e| Self::fallback_with_default_owner(SetupConfig::default(), settings, &e), + ) + } + + /// Create a wizard with custom configuration. + pub fn with_config(config: SetupConfig) -> Self { + let settings = crate::config::load_bootstrap_settings(None).unwrap_or_default(); + Self::from_bootstrap_settings(config.clone(), settings.clone()) + .unwrap_or_else(|e| Self::fallback_with_default_owner(config, settings, &e)) + } + + /// Create a wizard with custom configuration and bootstrap TOML overlay. + pub fn try_with_config_and_toml( + config: SetupConfig, + toml_path: Option<&std::path::Path>, + ) -> Result { + let settings = crate::config::load_bootstrap_settings(toml_path)?; + Self::from_bootstrap_settings(config, settings) } /// Set the session manager (for reusing existing auth). @@ -252,79 +300,115 @@ impl SetupWizard { /// database connection and the wizard's `self.settings` reflects the /// previously saved configuration. async fn reconnect_existing_db(&mut self) -> Result<(), SetupError> { - use crate::config::DatabaseConfig; + // Determine backend from env (set by bootstrap .env loaded in main). + let backend = std::env::var("DATABASE_BACKEND").unwrap_or_else(|_| "postgres".to_string()); + + // Try libsql first if that's the configured backend. + #[cfg(feature = "libsql")] + if backend == "libsql" || backend == "turso" || backend == "sqlite" { + return self.reconnect_libsql().await; + } + + // Try postgres (either explicitly configured or as default). + #[cfg(feature = "postgres")] + { + let _ = &backend; + return self.reconnect_postgres().await; + } + + #[allow(unreachable_code)] + Err(SetupError::Database( + "No database configured. Run full setup first (ironclaw onboard).".to_string(), + )) + } - let db_config = DatabaseConfig::resolve().map_err(|e| { - SetupError::Database(format!( - "Cannot resolve database config. Run full setup first (ironclaw onboard): {}", - e - )) + /// Reconnect to an existing PostgreSQL database and load settings. + #[cfg(feature = "postgres")] + async fn reconnect_postgres(&mut self) -> Result<(), SetupError> { + let url = std::env::var("DATABASE_URL").map_err(|_| { + SetupError::Database( + "DATABASE_URL not set. Run full setup first (ironclaw onboard).".to_string(), + ) })?; - let backend_name = db_config.backend.to_string(); - let (db, handles) = crate::db::connect_with_handles(&db_config) - .await - .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))?; + self.test_database_connection_postgres(&url).await?; + self.settings.database_backend = Some("postgres".to_string()); + self.settings.database_url = Some(url.clone()); - // Load existing settings from DB - if let Ok(map) = db.get_all_settings("default").await { - self.settings = Settings::from_db_map(&map); + // Load existing settings from DB, then restore connection fields that + // may not be persisted in the settings map. + if let Some(ref pool) = self.db_pool { + let store = crate::history::Store::from_pool(pool.clone()); + if let Ok(map) = store.get_all_settings(self.owner_id()).await { + self.settings = Settings::from_db_map(&map); + self.settings.database_backend = Some("postgres".to_string()); + self.settings.database_url = Some(url); + } } - // Restore connection fields that may not be persisted in the settings map - self.settings.database_backend = Some(backend_name); - if let Ok(url) = std::env::var("DATABASE_URL") { - self.settings.database_url = Some(url); - } - if let Ok(path) = std::env::var("LIBSQL_PATH") { - self.settings.libsql_path = Some(path); - } else if db_config.libsql_path.is_some() { - self.settings.libsql_path = db_config - .libsql_path - .as_ref() - .map(|p| p.to_string_lossy().to_string()); - } - if let Ok(url) = std::env::var("LIBSQL_URL") { - self.settings.libsql_url = Some(url); - } + Ok(()) + } + + /// Reconnect to an existing libSQL database and load settings. + #[cfg(feature = "libsql")] + async fn reconnect_libsql(&mut self) -> Result<(), SetupError> { + let path = std::env::var("LIBSQL_PATH").unwrap_or_else(|_| { + crate::config::default_libsql_path() + .to_string_lossy() + .to_string() + }); + let turso_url = std::env::var("LIBSQL_URL").ok(); + let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - self.db = Some(db); - self.db_handles = Some(handles); + self.test_database_connection_libsql(&path, turso_url.as_deref(), turso_token.as_deref()) + .await?; + + self.settings.database_backend = Some("libsql".to_string()); + self.settings.libsql_path = Some(path.clone()); + if let Some(ref url) = turso_url { + self.settings.libsql_url = Some(url.clone()); + } + + // Load existing settings from DB, then restore connection fields that + // may not be persisted in the settings map. + if let Some(ref db) = self.db_backend { + use crate::db::SettingsStore as _; + if let Ok(map) = db.get_all_settings(self.owner_id()).await { + self.settings = Settings::from_db_map(&map); + self.settings.database_backend = Some("libsql".to_string()); + self.settings.libsql_path = Some(path); + if let Some(url) = turso_url { + self.settings.libsql_url = Some(url); + } + } + } Ok(()) } /// Step 1: Database connection. - /// - /// Determines the backend at runtime (env var, interactive selection, or - /// compile-time default) and runs the appropriate configuration flow. async fn step_database(&mut self) -> Result<(), SetupError> { - use crate::config::{DatabaseBackend, DatabaseConfig}; - - const POSTGRES_AVAILABLE: bool = cfg!(feature = "postgres"); - const LIBSQL_AVAILABLE: bool = cfg!(feature = "libsql"); - - // Determine backend from env var, interactive selection, or default. - let env_backend = std::env::var("DATABASE_BACKEND").ok(); + // When both features are compiled, let the user choose. + // If DATABASE_BACKEND is already set in the environment, respect it. + #[cfg(all(feature = "postgres", feature = "libsql"))] + { + // Check if a backend is already pinned via env var + let env_backend = std::env::var("DATABASE_BACKEND").ok(); - let backend = if let Some(ref raw) = env_backend { - match raw.parse::() { - Ok(b) => b, - Err(_) => { - let fallback = if POSTGRES_AVAILABLE { - DatabaseBackend::Postgres - } else { - DatabaseBackend::LibSql - }; + if let Some(ref backend) = env_backend { + if backend == "libsql" || backend == "turso" || backend == "sqlite" { + return self.step_database_libsql().await; + } + if backend != "postgres" && backend != "postgresql" { print_info(&format!( - "Unknown DATABASE_BACKEND '{}', defaulting to {}", - raw, fallback + "Unknown DATABASE_BACKEND '{}', defaulting to PostgreSQL", + backend )); - fallback } + return self.step_database_postgres().await; } - } else if POSTGRES_AVAILABLE && LIBSQL_AVAILABLE { - // Both features compiled — offer interactive selection. + + // Interactive selection let pre_selected = self.settings.database_backend.as_deref().map(|b| match b { "libsql" | "turso" | "sqlite" => 1, _ => 0, @@ -350,82 +434,88 @@ impl SetupWizard { self.settings.libsql_url = None; } - if choice == 1 { - DatabaseBackend::LibSql - } else { - DatabaseBackend::Postgres + match choice { + 1 => return self.step_database_libsql().await, + _ => return self.step_database_postgres().await, } - } else if LIBSQL_AVAILABLE { - DatabaseBackend::LibSql - } else { - // Only postgres (or neither, but that won't compile anyway). - DatabaseBackend::Postgres - }; + } - // --- Postgres flow --- - if backend == DatabaseBackend::Postgres { - self.settings.database_backend = Some("postgres".to_string()); + #[cfg(all(feature = "postgres", not(feature = "libsql")))] + { + return self.step_database_postgres().await; + } - let existing_url = std::env::var("DATABASE_URL") - .ok() - .or_else(|| self.settings.database_url.clone()); + #[cfg(all(feature = "libsql", not(feature = "postgres")))] + { + return self.step_database_libsql().await; + } + } - if let Some(ref url) = existing_url { - let display_url = mask_password_in_url(url); - print_info(&format!("Existing database URL: {}", display_url)); + /// Step 1 (postgres): Database connection via PostgreSQL URL. + #[cfg(feature = "postgres")] + async fn step_database_postgres(&mut self) -> Result<(), SetupError> { + self.settings.database_backend = Some("postgres".to_string()); - if confirm("Use this database?", true).map_err(SetupError::Io)? { - let config = DatabaseConfig::from_postgres_url(url, 5); - if let Err(e) = self.test_database_connection(&config).await { - print_error(&format!("Connection failed: {}", e)); - print_info("Let's configure a new database URL."); - } else { - print_success("Database connection successful"); - self.settings.database_url = Some(url.clone()); - return Ok(()); - } + let existing_url = std::env::var("DATABASE_URL") + .ok() + .or_else(|| self.settings.database_url.clone()); + + if let Some(ref url) = existing_url { + let display_url = mask_password_in_url(url); + print_info(&format!("Existing database URL: {}", display_url)); + + if confirm("Use this database?", true).map_err(SetupError::Io)? { + if let Err(e) = self.test_database_connection_postgres(url).await { + print_error(&format!("Connection failed: {}", e)); + print_info("Let's configure a new database URL."); + } else { + print_success("Database connection successful"); + self.settings.database_url = Some(url.clone()); + return Ok(()); } } + } - println!(); - print_info("Enter your PostgreSQL connection URL."); - print_info("Format: postgres://user:password@host:port/database"); - println!(); - - loop { - let url = input("Database URL").map_err(SetupError::Io)?; + println!(); + print_info("Enter your PostgreSQL connection URL."); + print_info("Format: postgres://user:password@host:port/database"); + println!(); - if url.is_empty() { - print_error("Database URL is required."); - continue; - } + loop { + let url = input("Database URL").map_err(SetupError::Io)?; - print_info("Testing connection..."); - let config = DatabaseConfig::from_postgres_url(&url, 5); - match self.test_database_connection(&config).await { - Ok(()) => { - print_success("Database connection successful"); + if url.is_empty() { + print_error("Database URL is required."); + continue; + } - if confirm("Run database migrations?", true).map_err(SetupError::Io)? { - self.run_migrations().await?; - } + print_info("Testing connection..."); + match self.test_database_connection_postgres(&url).await { + Ok(()) => { + print_success("Database connection successful"); - self.settings.database_url = Some(url); - return Ok(()); + if confirm("Run database migrations?", true).map_err(SetupError::Io)? { + self.run_migrations_postgres().await?; } - Err(e) => { - print_error(&format!("Connection failed: {}", e)); - if !confirm("Try again?", true).map_err(SetupError::Io)? { - return Err(SetupError::Database( - "Database connection failed".to_string(), - )); - } + + self.settings.database_url = Some(url); + return Ok(()); + } + Err(e) => { + print_error(&format!("Connection failed: {}", e)); + if !confirm("Try again?", true).map_err(SetupError::Io)? { + return Err(SetupError::Database( + "Database connection failed".to_string(), + )); } } } } + } - // --- libSQL flow --- + /// Step 1 (libsql): Database connection via local file or Turso remote replica. + #[cfg(feature = "libsql")] + async fn step_database_libsql(&mut self) -> Result<(), SetupError> { self.settings.database_backend = Some("libsql".to_string()); let default_path = crate::config::default_libsql_path(); @@ -444,12 +534,14 @@ impl SetupWizard { .or_else(|| self.settings.libsql_url.clone()); let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - let config = DatabaseConfig::from_libsql_path( - path, - turso_url.as_deref(), - turso_token.as_deref(), - ); - match self.test_database_connection(&config).await { + match self + .test_database_connection_libsql( + path, + turso_url.as_deref(), + turso_token.as_deref(), + ) + .await + { Ok(()) => { print_success("Database connection successful"); self.settings.libsql_path = Some(path.clone()); @@ -508,17 +600,15 @@ impl SetupWizard { }; print_info("Testing connection..."); - let config = DatabaseConfig::from_libsql_path( - &db_path, - turso_url.as_deref(), - turso_token.as_deref(), - ); - match self.test_database_connection(&config).await { + match self + .test_database_connection_libsql(&db_path, turso_url.as_deref(), turso_token.as_deref()) + .await + { Ok(()) => { print_success("Database connection successful"); // Always run migrations for libsql (they're idempotent) - self.run_migrations().await?; + self.run_migrations_libsql().await?; self.settings.libsql_path = Some(db_path); if let Some(url) = turso_url { @@ -530,39 +620,155 @@ impl SetupWizard { } } - /// Test database connection using the db module factory. + /// Test PostgreSQL connection and store the pool. /// - /// Connects without running migrations and validates PostgreSQL - /// prerequisites (version, pgvector) when using the postgres backend. - async fn test_database_connection( + /// After connecting, validates: + /// 1. PostgreSQL version >= 15 (required for pgvector compatibility) + /// 2. pgvector extension is available (required for embeddings/vector search) + #[cfg(feature = "postgres")] + async fn test_database_connection_postgres(&mut self, url: &str) -> Result<(), SetupError> { + let mut cfg = PoolConfig::new(); + cfg.url = Some(url.to_string()); + cfg.pool = Some(deadpool_postgres::PoolConfig { + max_size: 5, + ..Default::default() + }); + + let pool = crate::db::tls::create_pool(&cfg, crate::config::SslMode::from_env()) + .map_err(|e| SetupError::Database(format!("Failed to create pool: {}", e)))?; + + let client = pool + .get() + .await + .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))?; + + // Check PostgreSQL server version (need 15+ for pgvector) + let version_row = client + .query_one("SHOW server_version", &[]) + .await + .map_err(|e| SetupError::Database(format!("Failed to query server version: {}", e)))?; + let version_str: &str = version_row.get(0); + let major_version = version_str + .split('.') + .next() + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + + const MIN_PG_MAJOR_VERSION: u32 = 15; + + if major_version < MIN_PG_MAJOR_VERSION { + return Err(SetupError::Database(format!( + "PostgreSQL {} detected. IronClaw requires PostgreSQL {} or later for pgvector support.\n\ + Upgrade: https://www.postgresql.org/download/", + version_str, MIN_PG_MAJOR_VERSION + ))); + } + + // Check if pgvector extension is available + let pgvector_row = client + .query_opt( + "SELECT 1 FROM pg_available_extensions WHERE name = 'vector'", + &[], + ) + .await + .map_err(|e| { + SetupError::Database(format!("Failed to check pgvector availability: {}", e)) + })?; + + if pgvector_row.is_none() { + return Err(SetupError::Database(format!( + "pgvector extension not found on your PostgreSQL server.\n\n\ + Install it:\n \ + macOS: brew install pgvector\n \ + Ubuntu: apt install postgresql-{0}-pgvector\n \ + Docker: use the pgvector/pgvector:pg{0} image\n \ + Source: https://github.com/pgvector/pgvector#installation\n\n\ + Then restart PostgreSQL and re-run: ironclaw onboard", + major_version + ))); + } + + self.db_pool = Some(pool); + Ok(()) + } + + /// Test libSQL connection and store the backend. + #[cfg(feature = "libsql")] + async fn test_database_connection_libsql( &mut self, - config: &crate::config::DatabaseConfig, + path: &str, + turso_url: Option<&str>, + turso_token: Option<&str>, ) -> Result<(), SetupError> { - let (db, handles) = crate::db::connect_without_migrations(config) - .await - .map_err(|e| SetupError::Database(e.to_string()))?; + use crate::db::libsql::LibSqlBackend; + use std::path::Path; + + let db_path = Path::new(path); + + let backend = if let (Some(url), Some(token)) = (turso_url, turso_token) { + LibSqlBackend::new_remote_replica(db_path, url, token) + .await + .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))? + } else { + LibSqlBackend::new_local(db_path) + .await + .map_err(|e| SetupError::Database(format!("Failed to open database: {}", e)))? + }; + + self.db_backend = Some(backend); + Ok(()) + } + + /// Run PostgreSQL migrations. + #[cfg(feature = "postgres")] + async fn run_migrations_postgres(&self) -> Result<(), SetupError> { + if let Some(ref pool) = self.db_pool { + use refinery::embed_migrations; + embed_migrations!("migrations"); + + if !self.config.quick { + print_info("Running migrations..."); + } + tracing::debug!("Running PostgreSQL migrations..."); + + let mut client = pool + .get() + .await + .map_err(|e| SetupError::Database(format!("Pool error: {}", e)))?; + + migrations::runner() + .run_async(&mut **client) + .await + .map_err(|e| SetupError::Database(format!("Migration failed: {}", e)))?; - self.db = Some(db); - self.db_handles = Some(handles); + if !self.config.quick { + print_success("Migrations applied"); + } + tracing::debug!("PostgreSQL migrations applied"); + } Ok(()) } - /// Run database migrations on the current connection. - async fn run_migrations(&self) -> Result<(), SetupError> { - if let Some(ref db) = self.db { + /// Run libSQL migrations. + #[cfg(feature = "libsql")] + async fn run_migrations_libsql(&self) -> Result<(), SetupError> { + if let Some(ref backend) = self.db_backend { + use crate::db::Database; + if !self.config.quick { print_info("Running migrations..."); } - tracing::debug!("Running database migrations..."); + tracing::debug!("Running libSQL migrations..."); - db.run_migrations() + backend + .run_migrations() .await .map_err(|e| SetupError::Database(format!("Migration failed: {}", e)))?; if !self.config.quick { print_success("Migrations applied"); } - tracing::debug!("Database migrations applied"); + tracing::debug!("libSQL migrations applied"); } Ok(()) } @@ -579,19 +785,20 @@ impl SetupWizard { return Ok(()); } - // Try to retrieve existing key from keychain via resolve_master_key - // (checks env var first, then keychain). We skip the env var case - // above, so this will only find a keychain key here. + // Try to retrieve existing key from keychain. We use get_master_key() + // instead of has_master_key() so we can cache the key bytes and build + // SecretsCrypto eagerly, avoiding redundant keychain accesses later + // (each access triggers macOS system dialogs). print_info("Checking OS keychain for existing master key..."); if let Ok(keychain_key_bytes) = crate::secrets::keychain::get_master_key().await { let key_hex: String = keychain_key_bytes .iter() .map(|b| format!("{:02x}", b)) .collect(); - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex)) .map_err(|e| SetupError::Config(e.to_string()))?, - ); + )); print_info("Existing master key found in OS keychain."); if confirm("Use existing keychain key?", true).map_err(SetupError::Io)? { @@ -630,11 +837,12 @@ impl SetupWizard { SetupError::Config(format!("Failed to store in keychain: {}", e)) })?; + // Also create crypto instance let key_hex: String = key.iter().map(|b| format!("{:02x}", b)).collect(); - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex)) .map_err(|e| SetupError::Config(e.to_string()))?, - ); + )); self.settings.secrets_master_key_source = KeySource::Keychain; print_success("Master key generated and stored in OS keychain"); @@ -645,10 +853,10 @@ impl SetupWizard { // Initialize crypto so subsequent wizard steps (channel setup, // API key storage) can encrypt secrets immediately. - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex.clone())) .map_err(|e| SetupError::Config(e.to_string()))?, - ); + )); // Make visible to optional_env() for any subsequent config resolution. crate::config::inject_single_var("SECRETS_MASTER_KEY", &key_hex); @@ -681,22 +889,16 @@ impl SetupWizard { /// standard path. Falls back to the interactive `step_database()` only when /// just the postgres feature is compiled (can't auto-default postgres). async fn auto_setup_database(&mut self) -> Result<(), SetupError> { - use crate::config::{DatabaseBackend, DatabaseConfig}; - - const POSTGRES_AVAILABLE: bool = cfg!(feature = "postgres"); - const LIBSQL_AVAILABLE: bool = cfg!(feature = "libsql"); - + // If DATABASE_URL or LIBSQL_PATH already set, respect existing config + #[cfg(feature = "postgres")] let env_backend = std::env::var("DATABASE_BACKEND").ok(); - // If DATABASE_BACKEND=postgres and DATABASE_URL exists: connect+migrate + #[cfg(feature = "postgres")] if let Some(ref backend) = env_backend - && let Ok(DatabaseBackend::Postgres) = backend.parse::() + && (backend == "postgres" || backend == "postgresql") { if let Ok(url) = std::env::var("DATABASE_URL") { print_info("Using existing PostgreSQL configuration"); - let config = DatabaseConfig::from_postgres_url(&url, 5); - self.test_database_connection(&config).await?; - self.run_migrations().await?; self.settings.database_backend = Some("postgres".to_string()); self.settings.database_url = Some(url); return Ok(()); @@ -705,23 +907,17 @@ impl SetupWizard { return self.step_database().await; } - // If DATABASE_URL exists (no explicit backend): connect+migrate as postgres, - // but only when the postgres feature is actually compiled in. - if POSTGRES_AVAILABLE - && env_backend.is_none() - && let Ok(url) = std::env::var("DATABASE_URL") - { + #[cfg(feature = "postgres")] + if let Ok(url) = std::env::var("DATABASE_URL") { print_info("Using existing PostgreSQL configuration"); - let config = DatabaseConfig::from_postgres_url(&url, 5); - self.test_database_connection(&config).await?; - self.run_migrations().await?; self.settings.database_backend = Some("postgres".to_string()); self.settings.database_url = Some(url); return Ok(()); } - // Auto-default to libsql if available - if LIBSQL_AVAILABLE { + // Auto-default to libsql if the feature is compiled + #[cfg(feature = "libsql")] + { self.settings.database_backend = Some("libsql".to_string()); let existing_path = std::env::var("LIBSQL_PATH") @@ -737,13 +933,14 @@ impl SetupWizard { let turso_url = std::env::var("LIBSQL_URL").ok(); let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - let config = DatabaseConfig::from_libsql_path( + self.test_database_connection_libsql( &db_path, turso_url.as_deref(), turso_token.as_deref(), - ); - self.test_database_connection(&config).await?; - self.run_migrations().await?; + ) + .await?; + + self.run_migrations_libsql().await?; self.settings.libsql_path = Some(db_path.clone()); if let Some(url) = turso_url { @@ -755,7 +952,10 @@ impl SetupWizard { } // Only postgres feature compiled — can't auto-default, use interactive - self.step_database().await + #[allow(unreachable_code)] + { + self.step_database().await + } } /// Auto-setup security with zero prompts (quick mode). @@ -764,23 +964,26 @@ impl SetupWizard { /// key if available, otherwise generates and stores one automatically /// (keychain on macOS, env var fallback). async fn auto_setup_security(&mut self) -> Result<(), SetupError> { - // Try resolving an existing key from env var or keychain - if let Some(key_hex) = crate::secrets::resolve_master_key().await { - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + // Check env var first + if std::env::var("SECRETS_MASTER_KEY").is_ok() { + self.settings.secrets_master_key_source = KeySource::Env; + print_success("Security configured (env var)"); + return Ok(()); + } + + // Try existing keychain key (no prompts — get_master_key may show + // OS dialogs on macOS, but that's unavoidable for keychain access) + if let Ok(keychain_key_bytes) = crate::secrets::keychain::get_master_key().await { + let key_hex: String = keychain_key_bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect(); + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex)) .map_err(|e| SetupError::Config(e.to_string()))?, - ); - // Determine source: env var or keychain (filter empty to match resolve_master_key) - let (source, label) = if std::env::var("SECRETS_MASTER_KEY") - .ok() - .is_some_and(|v| !v.is_empty()) - { - (KeySource::Env, "env var") - } else { - (KeySource::Keychain, "keychain") - }; - self.settings.secrets_master_key_source = source; - print_success(&format!("Security configured ({})", label)); + )); + self.settings.secrets_master_key_source = KeySource::Keychain; + print_success("Security configured (keychain)"); return Ok(()); } @@ -792,10 +995,10 @@ impl SetupWizard { .is_ok() { let key_hex: String = key.iter().map(|b| format!("{:02x}", b)).collect(); - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex)) .map_err(|e| SetupError::Config(e.to_string()))?, - ); + )); self.settings.secrets_master_key_source = KeySource::Keychain; print_success("Master key stored in OS keychain"); return Ok(()); @@ -803,10 +1006,10 @@ impl SetupWizard { // Keychain unavailable — fall back to env var mode let key_hex = crate::secrets::keychain::generate_master_key_hex(); - self.secrets_crypto = Some( - crate::secrets::crypto_from_hex(&key_hex) + self.secrets_crypto = Some(Arc::new( + SecretsCrypto::new(SecretString::from(key_hex.clone())) .map_err(|e| SetupError::Config(e.to_string()))?, - ); + )); crate::config::inject_single_var("SECRETS_MASTER_KEY", &key_hex); self.settings.secrets_master_key_hex = Some(key_hex); self.settings.secrets_master_key_source = KeySource::Env; @@ -1677,27 +1880,74 @@ impl SetupWizard { /// Initialize secrets context for channel setup. async fn init_secrets_context(&mut self) -> Result { - // Get crypto (should be set from step 2, or resolve from keychain/env) + // Get crypto (should be set from step 2, or load from keychain/env) let crypto = if let Some(ref c) = self.secrets_crypto { Arc::clone(c) } else { - let key_hex = crate::secrets::resolve_master_key().await.ok_or_else(|| { - SetupError::Config( + // Try to load master key from keychain or env + let key = if let Ok(env_key) = std::env::var("SECRETS_MASTER_KEY") { + env_key + } else if let Ok(keychain_key) = crate::secrets::keychain::get_master_key().await { + keychain_key.iter().map(|b| format!("{:02x}", b)).collect() + } else { + return Err(SetupError::Config( "Secrets not configured. Run full setup or set SECRETS_MASTER_KEY.".to_string(), - ) - })?; + )); + }; - let crypto = crate::secrets::crypto_from_hex(&key_hex) - .map_err(|e| SetupError::Config(e.to_string()))?; + let crypto = Arc::new( + SecretsCrypto::new(SecretString::from(key)) + .map_err(|e| SetupError::Config(e.to_string()))?, + ); self.secrets_crypto = Some(Arc::clone(&crypto)); crypto }; - // Create secrets store from existing database handles - if let Some(ref handles) = self.db_handles - && let Some(store) = crate::secrets::create_secrets_store(Arc::clone(&crypto), handles) - { - return Ok(SecretsContext::from_store(store, "default")); + // Create backend-appropriate secrets store. + // Use runtime dispatch based on the user's selected backend. + // Default to whichever backend is compiled in. When only libsql is + // available, we must not default to "postgres" or we'd skip store creation. + let default_backend = { + #[cfg(feature = "postgres")] + { + "postgres" + } + #[cfg(not(feature = "postgres"))] + { + "libsql" + } + }; + let selected_backend = self + .settings + .database_backend + .as_deref() + .unwrap_or(default_backend); + + match selected_backend { + #[cfg(feature = "libsql")] + "libsql" | "turso" | "sqlite" => { + if let Some(store) = self.create_libsql_secrets_store(&crypto)? { + return Ok(SecretsContext::from_store(store, self.owner_id())); + } + // Fallback to postgres if libsql store creation returned None + #[cfg(feature = "postgres")] + if let Some(store) = self.create_postgres_secrets_store(&crypto).await? { + return Ok(SecretsContext::from_store(store, self.owner_id())); + } + } + #[cfg(feature = "postgres")] + _ => { + if let Some(store) = self.create_postgres_secrets_store(&crypto).await? { + return Ok(SecretsContext::from_store(store, self.owner_id())); + } + // Fallback to libsql if postgres store creation returned None + #[cfg(feature = "libsql")] + if let Some(store) = self.create_libsql_secrets_store(&crypto)? { + return Ok(SecretsContext::from_store(store, self.owner_id())); + } + } + #[cfg(not(feature = "postgres"))] + _ => {} } Err(SetupError::Config( @@ -1705,6 +1955,62 @@ impl SetupWizard { )) } + /// Create a PostgreSQL secrets store from the current pool. + #[cfg(feature = "postgres")] + async fn create_postgres_secrets_store( + &mut self, + crypto: &Arc, + ) -> Result>, SetupError> { + let pool = if let Some(ref p) = self.db_pool { + p.clone() + } else { + // Fall back to creating one from settings/env + let url = self + .settings + .database_url + .clone() + .or_else(|| std::env::var("DATABASE_URL").ok()); + + if let Some(url) = url { + self.test_database_connection_postgres(&url).await?; + self.run_migrations_postgres().await?; + match self.db_pool.clone() { + Some(pool) => pool, + None => { + return Err(SetupError::Database( + "Database pool not initialized after connection test".to_string(), + )); + } + } + } else { + return Ok(None); + } + }; + + let store: Arc = Arc::new(crate::secrets::PostgresSecretsStore::new( + pool, + Arc::clone(crypto), + )); + Ok(Some(store)) + } + + /// Create a libSQL secrets store from the current backend. + #[cfg(feature = "libsql")] + fn create_libsql_secrets_store( + &self, + crypto: &Arc, + ) -> Result>, SetupError> { + if let Some(ref backend) = self.db_backend { + let store: Arc = Arc::new(crate::secrets::LibSqlSecretsStore::new( + backend.shared_db(), + Arc::clone(crypto), + )); + Ok(Some(store)) + } else { + Ok(None) + } + } + /// Step 6: Channel configuration. async fn step_channels(&mut self) -> Result<(), SetupError> { // First, configure tunnel (shared across all channels that need webhooks) @@ -2222,15 +2528,45 @@ impl SetupWizard { /// connection is available yet (e.g., before Step 1 completes). async fn persist_settings(&self) -> Result { let db_map = self.settings.to_db_map(); + let saved = false; + + #[cfg(feature = "postgres")] + let saved = if !saved { + if let Some(ref pool) = self.db_pool { + let store = crate::history::Store::from_pool(pool.clone()); + store + .set_all_settings(self.owner_id(), &db_map) + .await + .map_err(|e| { + SetupError::Database(format!("Failed to save settings to database: {}", e)) + })?; + true + } else { + false + } + } else { + saved + }; - if let Some(ref db) = self.db { - db.set_all_settings("default", &db_map).await.map_err(|e| { - SetupError::Database(format!("Failed to save settings to database: {}", e)) - })?; - Ok(true) + #[cfg(feature = "libsql")] + let saved = if !saved { + if let Some(ref backend) = self.db_backend { + use crate::db::SettingsStore as _; + backend + .set_all_settings(self.owner_id(), &db_map) + .await + .map_err(|e| { + SetupError::Database(format!("Failed to save settings to database: {}", e)) + })?; + true + } else { + false + } } else { - Ok(false) - } + saved + }; + + Ok(saved) } /// Write bootstrap environment variables to `~/.ironclaw/.env`. @@ -2406,12 +2742,28 @@ impl SetupWizard { Err(_) => return, }; - if let Some(ref db) = self.db { - if let Err(e) = db - .set_setting("default", "nearai.session_token", &value) + #[cfg(feature = "postgres")] + if let Some(ref pool) = self.db_pool { + let store = crate::history::Store::from_pool(pool.clone()); + if let Err(e) = store + .set_setting(self.owner_id(), "nearai.session_token", &value) + .await + { + tracing::debug!("Could not persist session token to postgres: {}", e); + } else { + tracing::debug!("Session token persisted to database"); + return; + } + } + + #[cfg(feature = "libsql")] + if let Some(ref backend) = self.db_backend { + use crate::db::SettingsStore as _; + if let Err(e) = backend + .set_setting(self.owner_id(), "nearai.session_token", &value) .await { - tracing::debug!("Could not persist session token to database: {}", e); + tracing::debug!("Could not persist session token to libsql: {}", e); } else { tracing::debug!("Session token persisted to database"); } @@ -2448,19 +2800,58 @@ impl SetupWizard { /// prefers the `other` argument's non-default values. Without this, /// stale DB values would overwrite fresh user choices. async fn try_load_existing_settings(&mut self) { - if let Some(ref db) = self.db { - match db.get_all_settings("default").await { - Ok(db_map) if !db_map.is_empty() => { - let existing = Settings::from_db_map(&db_map); - self.settings.merge_from(&existing); - tracing::info!("Loaded {} existing settings from database", db_map.len()); + let loaded = false; + + #[cfg(feature = "postgres")] + let loaded = if !loaded { + if let Some(ref pool) = self.db_pool { + let store = crate::history::Store::from_pool(pool.clone()); + match store.get_all_settings(self.owner_id()).await { + Ok(db_map) if !db_map.is_empty() => { + let existing = Settings::from_db_map(&db_map); + self.settings.merge_from(&existing); + tracing::info!("Loaded {} existing settings from database", db_map.len()); + true + } + Ok(_) => false, + Err(e) => { + tracing::debug!("Could not load existing settings: {}", e); + false + } } - Ok(_) => {} - Err(e) => { - tracing::debug!("Could not load existing settings: {}", e); + } else { + false + } + } else { + loaded + }; + + #[cfg(feature = "libsql")] + let loaded = if !loaded { + if let Some(ref backend) = self.db_backend { + use crate::db::SettingsStore as _; + match backend.get_all_settings(self.owner_id()).await { + Ok(db_map) if !db_map.is_empty() => { + let existing = Settings::from_db_map(&db_map); + self.settings.merge_from(&existing); + tracing::info!("Loaded {} existing settings from database", db_map.len()); + true + } + Ok(_) => false, + Err(e) => { + tracing::debug!("Could not load existing settings: {}", e); + false + } } + } else { + false } - } + } else { + loaded + }; + + // Suppress unused variable warning when only one backend is compiled. + let _ = loaded; } /// Save settings to the database and `~/.ironclaw/.env`, then print summary. @@ -2610,6 +3001,7 @@ impl Default for SetupWizard { } /// Mask password in a database URL for display. +#[cfg(feature = "postgres")] fn mask_password_in_url(url: &str) -> String { // URL format: scheme://user:password@host/database // Find "://" to locate start of credentials @@ -2911,12 +3303,13 @@ async fn install_selected_bundled_channels( #[cfg(test)] mod tests { use std::collections::HashSet; + #[cfg(unix)] + use std::ffi::OsString; use tempfile::tempdir; use super::*; use crate::config::helpers::ENV_MUTEX; - use crate::llm::models::{is_openai_chat_model, sort_openai_models}; #[test] fn test_wizard_creation() { @@ -2938,6 +3331,53 @@ mod tests { } #[test] + fn test_wizard_owner_id_uses_resolved_env_scope() { + let _guard = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()); + let _owner = EnvGuard::set("IRONCLAW_OWNER_ID", " wizard-owner "); + + let wizard = SetupWizard::new(); + assert_eq!(wizard.owner_id(), "wizard-owner"); // safety: test-only assertion + } + + #[test] + fn test_wizard_owner_id_uses_toml_scope() { + let _guard = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()); + let _owner = EnvGuard::clear("IRONCLAW_OWNER_ID"); + let dir = tempdir().unwrap(); // safety: test-only tempdir setup + let path = dir.path().join("config.toml"); + std::fs::write(&path, "owner_id = \"toml-owner\"\n").unwrap(); // safety: test-only fixture write + + let wizard = SetupWizard::try_with_config_and_toml(Default::default(), Some(&path)) + .expect("wizard should load owner_id from TOML"); // safety: test-only assertion + assert_eq!(wizard.owner_id(), "toml-owner"); // safety: test-only assertion + } + + #[test] + #[cfg(unix)] + fn test_try_with_config_and_toml_propagates_invalid_owner_env() { + use std::os::unix::ffi::OsStringExt; + + let _guard = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()); + let original = std::env::var_os("IRONCLAW_OWNER_ID"); + unsafe { + std::env::set_var("IRONCLAW_OWNER_ID", OsString::from_vec(vec![0x66, 0x80])); + } + + let result = SetupWizard::try_with_config_and_toml(Default::default(), None); + + unsafe { + if let Some(value) = original { + std::env::set_var("IRONCLAW_OWNER_ID", value); + } else { + std::env::remove_var("IRONCLAW_OWNER_ID"); + } + } + + assert!(result.is_err()); // safety: test-only assertion + } + + #[test] + #[cfg(feature = "postgres")] fn test_mask_password_in_url() { assert_eq!( mask_password_in_url("postgres://user:secret@localhost/db"), @@ -2981,12 +3421,12 @@ mod tests { return; } - let dir = tempdir().unwrap(); + let dir = tempdir().unwrap(); // safety: test-only tempdir setup let installed = HashSet::::new(); install_missing_bundled_channels(dir.path(), &installed) .await - .unwrap(); + .unwrap(); // safety: test-only assertion assert!(dir.path().join("telegram.wasm").exists()); assert!(dir.path().join("telegram.capabilities.json").exists()); @@ -3088,7 +3528,7 @@ mod tests { #[tokio::test] async fn test_discover_wasm_channels_empty_dir() { - let dir = tempdir().unwrap(); + let dir = tempdir().unwrap(); // safety: test-only tempdir setup let channels = discover_wasm_channels(dir.path()).await; assert!(channels.is_empty()); } diff --git a/src/testing/mod.rs b/src/testing/mod.rs index 33702e679f..ff522e3ad2 100644 --- a/src/testing/mod.rs +++ b/src/testing/mod.rs @@ -439,6 +439,7 @@ impl TestHarnessBuilder { }; let deps = AgentDeps { + owner_id: "default".to_string(), store: Some(Arc::clone(&db)), llm, cheap_llm: None, @@ -1077,7 +1078,7 @@ mod tests { }, notify: NotifyConfig { channel: None, - user: "user1".to_string(), + user: Some("user1".to_string()), on_attention: true, on_failure: true, on_success: false, @@ -1210,7 +1211,7 @@ mod tests { }, notify: NotifyConfig { channel: None, - user: "user1".to_string(), + user: Some("user1".to_string()), on_attention: false, on_failure: false, on_success: false, diff --git a/src/tools/builtin/message.rs b/src/tools/builtin/message.rs index 53d16e78f1..b150c951e1 100644 --- a/src/tools/builtin/message.rs +++ b/src/tools/builtin/message.rs @@ -129,21 +129,28 @@ impl Tool for MessageTool { .map(|c| c.to_string()) }; - // Get target: use param → conversation default → job metadata + // Get target: use param → conversation default → job metadata → owner scope + // fallback when a specific channel is known. let target = if let Some(t) = params.get("target").and_then(|v| v.as_str()) { - t.to_string() + Some(t.to_string()) } else if let Some(t) = self .default_target .read() .unwrap_or_else(|e| e.into_inner()) .clone() { - t + Some(t) } else if let Some(t) = ctx.metadata.get("notify_user").and_then(|v| v.as_str()) { - t.to_string() + Some(t.to_string()) + } else if channel.is_some() { + Some(ctx.user_id.clone()) } else { + None + }; + + let Some(target) = target else { return Err(ToolError::ExecutionFailed( - "No target specified and no active conversation. Provide target parameter." + "No target specified and no channel-scoped routing target could be resolved. Provide target parameter." .to_string(), )); }; @@ -659,6 +666,31 @@ mod tests { ); } + #[tokio::test] + async fn message_tool_falls_back_to_ctx_user_when_channel_known() { + // Regression for owner-scoped notifications: a channel can be known + // even when the concrete delivery target is omitted, so the message + // tool should pass ctx.user_id through to the channel layer. + let tool = MessageTool::new(Arc::new(ChannelManager::new())); + + let mut ctx = + crate::context::JobContext::with_user("owner-scope", "routine-job", "price alert"); + ctx.metadata = serde_json::json!({ + "notify_channel": "telegram", + }); + + let result = tool + .execute(serde_json::json!({"content": "NEAR price is $5"}), &ctx) + .await; + + assert!(result.is_err()); // safety: test-only assertion + let err = result.unwrap_err().to_string(); + let mentions_missing_target = err.contains("No target specified"); + assert!(!mentions_missing_target); // safety: test-only assertion + let mentions_missing_channel = err.contains("No channel specified"); + assert!(!mentions_missing_channel); // safety: test-only assertion + } + #[tokio::test] async fn message_tool_no_metadata_still_errors() { // When neither conversation context nor metadata is set, should still diff --git a/src/tools/builtin/routine.rs b/src/tools/builtin/routine.rs index 42a771d3ba..347cb4ff07 100644 --- a/src/tools/builtin/routine.rs +++ b/src/tools/builtin/routine.rs @@ -106,7 +106,7 @@ pub(crate) fn routine_create_parameters_schema() -> serde_json::Value { }, "notify_user": { "type": "string", - "description": "User or destination to notify, for example a username or chat ID." + "description": "Optional explicit user or destination to notify, for example a username or chat ID. Omit it to use the configured owner's last-seen target for that channel." }, "timezone": { "type": "string", @@ -387,8 +387,7 @@ impl Tool for RoutineCreateTool { user: params .get("notify_user") .and_then(|v| v.as_str()) - .unwrap_or("default") - .to_string(), + .map(String::from), ..NotifyConfig::default() }, last_run_at: None, diff --git a/src/tools/wasm/wrapper.rs b/src/tools/wasm/wrapper.rs index bceb940169..be089dd83b 100644 --- a/src/tools/wasm/wrapper.rs +++ b/src/tools/wasm/wrapper.rs @@ -841,13 +841,7 @@ impl Tool for WasmToolWrapper { // Pre-resolve host credentials from secrets store (async, before blocking task). // This decrypts the secrets once so the sync http_request() host function // can inject them without needing async access. - // - // BUG FIX: ExtensionManager stores OAuth tokens under user_id "default" - // (hardcoded at construction in app.rs), but this was previously looking - // them up under ctx.user_id — which could be a Telegram user ID, web - // gateway user, etc. — causing credential resolution to silently fail. - // Must match the storage key until per-user credential isolation is added. - let credential_user_id = "default"; + let credential_user_id = &ctx.user_id; let host_credentials = resolve_host_credentials( &self.capabilities, self.secrets_store.as_deref(), @@ -1165,6 +1159,13 @@ async fn resolve_host_credentials( let secret = match store.get_decrypted(user_id, &mapping.secret_name).await { Ok(s) => Some(s), Err(e) => { + tracing::trace!( + user_id = %user_id, + secret_name = %mapping.secret_name, + error = %e, + "No matching host credential resolved for WASM tool in the requested scope" + ); + // If lookup fails and we're not already looking up "default", try "default" as fallback if user_id != "default" { tracing::debug!( @@ -1385,7 +1386,16 @@ fn build_tool_usage_hint(tool_name: &str, schema: &serde_json::Value) -> String #[cfg(test)] mod tests { - use std::sync::Arc; + use std::sync::{Arc, Mutex}; + + use async_trait::async_trait; + use uuid::Uuid; + + use crate::context::JobContext; + use crate::secrets::{ + CreateSecretParams, DecryptedSecret, InMemorySecretsStore, Secret, SecretError, SecretRef, + SecretsStore, + }; use crate::testing::credentials::{ TEST_BEARER_TOKEN_123, TEST_GOOGLE_OAUTH_FRESH, TEST_GOOGLE_OAUTH_LEGACY, @@ -1396,6 +1406,78 @@ mod tests { use crate::tools::wasm::capabilities::Capabilities; use crate::tools::wasm::runtime::{WasmRuntimeConfig, WasmToolRuntime}; + struct RecordingSecretsStore { + inner: InMemorySecretsStore, + get_decrypted_lookups: Mutex>, + } + + impl RecordingSecretsStore { + fn new() -> Self { + Self { + inner: test_secrets_store(), + get_decrypted_lookups: Mutex::new(Vec::new()), + } + } + + fn decrypted_lookups(&self) -> Vec<(String, String)> { + self.get_decrypted_lookups.lock().unwrap().clone() + } + } + + #[async_trait] + impl SecretsStore for RecordingSecretsStore { + async fn create( + &self, + user_id: &str, + params: CreateSecretParams, + ) -> Result { + self.inner.create(user_id, params).await + } + + async fn get(&self, user_id: &str, name: &str) -> Result { + self.inner.get(user_id, name).await + } + + async fn get_decrypted( + &self, + user_id: &str, + name: &str, + ) -> Result { + self.get_decrypted_lookups + .lock() + .unwrap() + .push((user_id.to_string(), name.to_string())); + self.inner.get_decrypted(user_id, name).await + } + + async fn exists(&self, user_id: &str, name: &str) -> Result { + self.inner.exists(user_id, name).await + } + + async fn list(&self, user_id: &str) -> Result, SecretError> { + self.inner.list(user_id).await + } + + async fn delete(&self, user_id: &str, name: &str) -> Result { + self.inner.delete(user_id, name).await + } + + async fn record_usage(&self, secret_id: Uuid) -> Result<(), SecretError> { + self.inner.record_usage(secret_id).await + } + + async fn is_accessible( + &self, + user_id: &str, + secret_name: &str, + allowed_secrets: &[String], + ) -> Result { + self.inner + .is_accessible(user_id, secret_name, allowed_secrets) + .await + } + } + #[test] fn test_wrapper_creation() { // This test verifies the runtime can be created @@ -1691,6 +1773,104 @@ mod tests { ); } + #[tokio::test] + async fn test_resolve_host_credentials_owner_scope_bearer() { + use std::collections::HashMap; + + use crate::secrets::{ + CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore, + }; + use crate::tools::wasm::capabilities::HttpCapability; + use crate::tools::wasm::wrapper::resolve_host_credentials; + + let store = test_secrets_store(); + let ctx = JobContext::with_user("owner-scope", "owner-scope test", "owner-scope test"); + + store + .create( + &ctx.user_id, + CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_TOKEN), + ) + .await + .unwrap(); + + let mut credentials = HashMap::new(); + credentials.insert( + "google_oauth_token".to_string(), + CredentialMapping { + secret_name: "google_oauth_token".to_string(), + location: CredentialLocation::AuthorizationBearer, + host_patterns: vec!["www.googleapis.com".to_string()], + }, + ); + + let caps = Capabilities { + http: Some(HttpCapability { + credentials, + ..Default::default() + }), + ..Default::default() + }; + + let result = resolve_host_credentials(&caps, Some(&store), &ctx.user_id, None).await; + assert_eq!(result.len(), 1); + assert_eq!( + result[0].headers.get("Authorization"), + Some(&format!("Bearer {TEST_GOOGLE_OAUTH_TOKEN}")) + ); + } + + #[tokio::test] + async fn test_execute_resolves_host_credentials_from_owner_scope_context() { + use std::collections::HashMap; + + use crate::secrets::{CredentialLocation, CredentialMapping}; + use crate::tools::wasm::capabilities::HttpCapability; + + let runtime = Arc::new(WasmToolRuntime::new(WasmRuntimeConfig::for_testing()).unwrap()); + let prepared = runtime + .prepare("search", b"\0asm\x0d\0\x01\0", None) + .await + .unwrap(); + let store = Arc::new(RecordingSecretsStore::new()); + let ctx = JobContext::with_user("owner-scope", "owner-scope test", "owner-scope test"); + + store + .create( + &ctx.user_id, + CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_TOKEN), + ) + .await + .unwrap(); + + let mut credentials = HashMap::new(); + credentials.insert( + "google_oauth_token".to_string(), + CredentialMapping { + secret_name: "google_oauth_token".to_string(), + location: CredentialLocation::AuthorizationBearer, + host_patterns: vec!["www.googleapis.com".to_string()], + }, + ); + + let caps = Capabilities { + http: Some(HttpCapability { + credentials, + ..Default::default() + }), + ..Default::default() + }; + + let wrapper = super::WasmToolWrapper::new(Arc::clone(&runtime), prepared, caps) + .with_secrets_store(store.clone()); + let result = wrapper.execute(serde_json::json!({}), &ctx).await; + assert!(result.is_err()); + + let lookups = store.decrypted_lookups(); + assert!(lookups.contains(&("owner-scope".to_string(), "google_oauth_token".to_string()))); + assert!(!lookups.contains(&("default".to_string(), "google_oauth_token".to_string()))); + } + #[tokio::test] async fn test_resolve_host_credentials_missing_secret() { use std::collections::HashMap; diff --git a/src/transcription/chat_completions.rs b/src/transcription/chat_completions.rs new file mode 100644 index 0000000000..e23818aaa2 --- /dev/null +++ b/src/transcription/chat_completions.rs @@ -0,0 +1,179 @@ +//! Chat Completions-based transcription provider. +//! +//! Uses the `/v1/chat/completions` endpoint with `input_audio` content type +//! to transcribe audio. Compatible with OpenRouter, OpenAI GPT-4o-audio, and +//! any provider that supports audio input via the Chat Completions API. + +use async_trait::async_trait; +use base64::Engine; +use secrecy::{ExposeSecret, SecretString}; + +use super::{AudioFormat, TranscriptionError, TranscriptionProvider}; + +/// Transcription provider that sends audio via the Chat Completions API. +/// +/// Unlike the Whisper provider (which uses `/v1/audio/transcriptions` with +/// multipart upload), this provider sends base64-encoded audio as an +/// `input_audio` content part in a chat message, enabling use with +/// OpenRouter and other providers that only expose audio through the +/// Chat Completions API. +pub struct ChatCompletionsTranscriptionProvider { + client: reqwest::Client, + api_key: SecretString, + model: String, + base_url: String, +} + +impl ChatCompletionsTranscriptionProvider { + /// Create a new provider with the given API key. + pub fn new(api_key: SecretString) -> Self { + Self { + client: match reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .build() + { + Ok(c) => c, + Err(e) => { + tracing::error!( + "Failed to build HTTP client with timeout, falling back to default: {e}" + ); + reqwest::Client::default() + } + }, + api_key, + model: "google/gemini-2.0-flash-001".to_string(), + base_url: "https://openrouter.ai/api".to_string(), + } + } + + /// Override the base URL. + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into().trim_end_matches('/').to_string(); + self + } + + /// Override the model name. + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = model.into(); + self + } +} + +/// Map [`AudioFormat`] to the format string expected by the Chat Completions API. +fn audio_format_str(format: AudioFormat) -> &'static str { + match format { + AudioFormat::Ogg => "ogg", + AudioFormat::Mp3 => "mp3", + AudioFormat::Mp4 => "mp4", + AudioFormat::Wav => "wav", + AudioFormat::Webm => "webm", + AudioFormat::Flac => "flac", + AudioFormat::M4a => "m4a", + } +} + +#[async_trait] +impl TranscriptionProvider for ChatCompletionsTranscriptionProvider { + async fn transcribe( + &self, + audio_data: &[u8], + format: AudioFormat, + ) -> Result { + if audio_data.is_empty() { + return Err(TranscriptionError::EmptyAudio); + } + + let b64 = base64::engine::general_purpose::STANDARD.encode(audio_data); + + let body = serde_json::json!({ + "model": self.model, + "messages": [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "Transcribe this audio. Return only the transcript text, nothing else." + }, + { + "type": "input_audio", + "input_audio": { + "data": b64, + "format": audio_format_str(format) + } + } + ] + }] + }); + + let url = format!("{}/v1/chat/completions", self.base_url); + + let response = self + .client + .post(&url) + .header( + "Authorization", + format!("Bearer {}", self.api_key.expose_secret()), + ) + .json(&body) + .send() + .await + .map_err(|e| TranscriptionError::RequestFailed(e.to_string()))?; + + let status = response.status(); + if !status.is_success() { + let body = response + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + return Err(TranscriptionError::RequestFailed(format!( + "HTTP {}: {}", + status, body + ))); + } + + let json: serde_json::Value = response + .json() + .await + .map_err(|e| TranscriptionError::RequestFailed(e.to_string()))?; + + // Extract text from the standard Chat Completions response format: + // { "choices": [{ "message": { "content": "..." } }] } + let text = json + .get("choices") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("message")) + .and_then(|m| m.get("content")) + .and_then(|c| c.as_str()) + .ok_or_else(|| { + TranscriptionError::RequestFailed( + "unexpected response format: missing choices[0].message.content".to_string(), + ) + })?; + + Ok(text.trim().to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn audio_format_str_maps_all_variants() { + assert_eq!(audio_format_str(AudioFormat::Ogg), "ogg"); + assert_eq!(audio_format_str(AudioFormat::Mp3), "mp3"); + assert_eq!(audio_format_str(AudioFormat::Mp4), "mp4"); + assert_eq!(audio_format_str(AudioFormat::Wav), "wav"); + assert_eq!(audio_format_str(AudioFormat::Webm), "webm"); + assert_eq!(audio_format_str(AudioFormat::Flac), "flac"); + assert_eq!(audio_format_str(AudioFormat::M4a), "m4a"); + } + + #[tokio::test] + async fn rejects_empty_audio() { + let provider = + ChatCompletionsTranscriptionProvider::new(SecretString::from("test-key".to_string())); + let result = provider.transcribe(&[], AudioFormat::Ogg).await; + assert!(matches!(result, Err(TranscriptionError::EmptyAudio))); + } +} diff --git a/src/transcription/mod.rs b/src/transcription/mod.rs index d0a7d31c01..ab2e43f94b 100644 --- a/src/transcription/mod.rs +++ b/src/transcription/mod.rs @@ -4,8 +4,10 @@ //! backends and a [`TranscriptionMiddleware`] that detects audio attachments //! on incoming messages and replaces them with transcribed text. +mod chat_completions; mod openai; +pub use self::chat_completions::ChatCompletionsTranscriptionProvider; pub use self::openai::OpenAiWhisperProvider; use async_trait::async_trait; diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index b19c77af1a..06c7da0384 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -15,7 +15,13 @@ import pytest -from helpers import AUTH_TOKEN, wait_for_port_line, wait_for_ready +from helpers import ( + AUTH_TOKEN, + HTTP_WEBHOOK_SECRET, + OWNER_SCOPE_ID, + wait_for_port_line, + wait_for_ready, +) # Project root (two levels up from tests/e2e/) ROOT = Path(__file__).resolve().parent.parent.parent @@ -92,6 +98,21 @@ def _find_free_port() -> int: return s.getsockname()[1] +def _reserve_loopback_sockets(count: int) -> list[socket.socket]: + """Bind loopback sockets and keep them open until the server starts.""" + sockets: list[socket.socket] = [] + try: + while len(sockets) < count: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", 0)) + sockets.append(sock) + return sockets + except Exception: + for sock in sockets: + sock.close() + raise + + @pytest.fixture(scope="session") def ironclaw_binary(): """Ensure ironclaw binary is built. Returns the binary path.""" @@ -108,6 +129,21 @@ def ironclaw_binary(): return str(binary) +@pytest.fixture(scope="session") +def server_ports(): + """Reserve dynamic ports for the gateway and HTTP webhook channel.""" + reserved = _reserve_loopback_sockets(2) + try: + yield { + "gateway": reserved[0].getsockname()[1], + "http": reserved[1].getsockname()[1], + "sockets": reserved, + } + finally: + for sock in reserved: + sock.close() + + @pytest.fixture(scope="session") async def mock_llm_server(): """Start the mock LLM server. Yields the base URL.""" @@ -177,10 +213,19 @@ def _wasm_build_symlinks(): @pytest.fixture(scope="session") -async def ironclaw_server(ironclaw_binary, mock_llm_server, wasm_tools_dir): +async def ironclaw_server( + ironclaw_binary, + mock_llm_server, + wasm_tools_dir, + server_ports, +): """Start the ironclaw gateway. Yields the base URL.""" - gateway_port = _find_free_port() home_dir = _HOME_TMPDIR.name + gateway_port = server_ports["gateway"] + http_port = server_ports["http"] + for sock in server_ports["sockets"]: + if sock.fileno() != -1: + sock.close() env = { # Minimal env: PATH for process spawning, HOME for Rust/cargo defaults "PATH": os.environ.get("PATH", "/usr/bin:/bin"), @@ -188,11 +233,15 @@ async def ironclaw_server(ironclaw_binary, mock_llm_server, wasm_tools_dir): "IRONCLAW_BASE_DIR": os.path.join(home_dir, ".ironclaw"), "RUST_LOG": "ironclaw=info", "RUST_BACKTRACE": "1", + "IRONCLAW_OWNER_ID": OWNER_SCOPE_ID, "GATEWAY_ENABLED": "true", "GATEWAY_HOST": "127.0.0.1", "GATEWAY_PORT": str(gateway_port), "GATEWAY_AUTH_TOKEN": AUTH_TOKEN, - "GATEWAY_USER_ID": "e2e-tester", + "GATEWAY_USER_ID": "e2e-web-sender", + "HTTP_HOST": "127.0.0.1", + "HTTP_PORT": str(http_port), + "HTTP_WEBHOOK_SECRET": HTTP_WEBHOOK_SECRET, "CLI_ENABLED": "false", "LLM_BACKEND": "openai_compatible", "LLM_BASE_URL": mock_llm_server, @@ -262,15 +311,22 @@ async def ironclaw_server(ironclaw_binary, mock_llm_server, wasm_tools_dir): @pytest.fixture(scope="session") -async def ironclaw_server_with_webhook_secret(ironclaw_binary, mock_llm_server, wasm_tools_dir): - """Start ironclaw with HTTP_WEBHOOK_SECRET configured for webhook tests. +async def http_channel_server(ironclaw_server, server_ports): + """HTTP webhook channel base URL.""" + base_url = f"http://127.0.0.1:{server_ports['http']}" + await wait_for_ready(f"{base_url}/health", timeout=30) + return base_url - Yields a dict with: - - 'url': base URL of the gateway - - 'secret': the webhook secret value - """ + +@pytest.fixture(scope="session") +async def http_channel_server_without_secret( + ironclaw_binary, + mock_llm_server, + wasm_tools_dir, +): + """Start the HTTP webhook channel without a configured secret.""" gateway_port = _find_free_port() - webhook_secret = "test-webhook-secret-e2e-12345" + http_port = _find_free_port() env = { # Minimal env: PATH for process spawning, HOME for Rust/cargo defaults "PATH": os.environ.get("PATH", "/usr/bin:/bin"), @@ -282,13 +338,14 @@ async def ironclaw_server_with_webhook_secret(ironclaw_binary, mock_llm_server, "GATEWAY_PORT": str(gateway_port), "GATEWAY_AUTH_TOKEN": AUTH_TOKEN, "GATEWAY_USER_ID": "e2e-tester", - "HTTP_WEBHOOK_SECRET": webhook_secret, + "HTTP_HOST": "127.0.0.1", + "HTTP_PORT": str(http_port), "CLI_ENABLED": "false", "LLM_BACKEND": "openai_compatible", "LLM_BASE_URL": mock_llm_server, "LLM_MODEL": "mock-model", "DATABASE_BACKEND": "libsql", - "LIBSQL_PATH": os.path.join(_DB_TMPDIR.name, "e2e-webhook.db"), + "LIBSQL_PATH": os.path.join(_DB_TMPDIR.name, "e2e-webhook-no-secret.db"), "SANDBOX_ENABLED": "false", "SKILLS_ENABLED": "true", "ROUTINES_ENABLED": "false", @@ -318,13 +375,12 @@ async def ironclaw_server_with_webhook_secret(ironclaw_binary, mock_llm_server, stderr=asyncio.subprocess.PIPE, env=env, ) - base_url = f"http://127.0.0.1:{gateway_port}" + gateway_url = f"http://127.0.0.1:{gateway_port}" + http_base_url = f"http://127.0.0.1:{http_port}" try: - await wait_for_ready(f"{base_url}/api/health", timeout=60) - yield { - "url": base_url, - "secret": webhook_secret, - } + await wait_for_ready(f"{gateway_url}/api/health", timeout=60) + await wait_for_ready(f"{http_base_url}/health", timeout=30) + yield http_base_url except TimeoutError: # Dump stderr so CI logs show why the server failed to start returncode = proc.returncode @@ -337,7 +393,8 @@ async def ironclaw_server_with_webhook_secret(ironclaw_binary, mock_llm_server, stderr_text = stderr_bytes.decode("utf-8", errors="replace") proc.kill() pytest.fail( - f"ironclaw server with webhook secret failed to start on port {gateway_port} " + f"ironclaw server without webhook secret failed to start on ports " + f"gateway={gateway_port}, http={http_port} " f"(returncode={returncode}).\nstderr:\n{stderr_text}" ) finally: diff --git a/tests/e2e/helpers.py b/tests/e2e/helpers.py index 629205a147..a0c498e575 100644 --- a/tests/e2e/helpers.py +++ b/tests/e2e/helpers.py @@ -1,6 +1,8 @@ """Shared helpers for E2E tests.""" import asyncio +import hashlib +import hmac import re import time @@ -95,12 +97,21 @@ "toast_success": ".toast.toast-success", "toast_error": ".toast.toast-error", "toast_info": ".toast.toast-info", + # Jobs / routines + "jobs_tbody": "#jobs-tbody", + "job_row": "#jobs-tbody .job-row", + "jobs_empty": "#jobs-empty", + "routines_tbody": "#routines-tbody", + "routine_row": "#routines-tbody .routine-row", + "routines_empty": "#routines-empty", } TABS = ["chat", "memory", "jobs", "routines", "extensions", "skills"] # Auth token used across all tests AUTH_TOKEN = "e2e-test-token" +OWNER_SCOPE_ID = "e2e-owner-scope" +HTTP_WEBHOOK_SECRET = "e2e-http-webhook-secret" async def wait_for_ready(url: str, *, timeout: float = 60, interval: float = 0.5): @@ -162,3 +173,16 @@ async def api_post(base_url: str, path: str, **kwargs) -> httpx.Response: timeout=kwargs.pop("timeout", 10), **kwargs, ) + + +def signed_http_webhook_headers(body: bytes) -> dict[str, str]: + """Return headers for the owner-scoped HTTP webhook channel.""" + digest = hmac.new( + HTTP_WEBHOOK_SECRET.encode("utf-8"), + body, + hashlib.sha256, + ).hexdigest() + return { + "Content-Type": "application/json", + "X-Hub-Signature-256": f"sha256={digest}", + } diff --git a/tests/e2e/ironclaw_e2e.egg-info/SOURCES.txt b/tests/e2e/ironclaw_e2e.egg-info/SOURCES.txt index 7f0113823f..c2784f643b 100644 --- a/tests/e2e/ironclaw_e2e.egg-info/SOURCES.txt +++ b/tests/e2e/ironclaw_e2e.egg-info/SOURCES.txt @@ -12,11 +12,17 @@ scenarios/test_csp.py scenarios/test_extension_oauth.py scenarios/test_extensions.py scenarios/test_html_injection.py +scenarios/test_mcp_auth_flow.py scenarios/test_oauth_credential_fallback.py +scenarios/test_owner_scope.py scenarios/test_pairing.py +scenarios/test_routine_event_batch.py scenarios/test_routine_oauth_credential_injection.py scenarios/test_skills.py scenarios/test_sse_reconnect.py +scenarios/test_telegram_hot_activation.py +scenarios/test_telegram_token_validation.py scenarios/test_tool_approval.py scenarios/test_tool_execution.py -scenarios/test_wasm_lifecycle.py \ No newline at end of file +scenarios/test_wasm_lifecycle.py +scenarios/test_webhook.py \ No newline at end of file diff --git a/tests/e2e/mock_llm.py b/tests/e2e/mock_llm.py index 175accf520..b091fc1739 100644 --- a/tests/e2e/mock_llm.py +++ b/tests/e2e/mock_llm.py @@ -26,6 +26,59 @@ TOOL_CALL_PATTERNS = [ (re.compile(r"echo (.+)", re.IGNORECASE), "echo", lambda m: {"message": m.group(1)}), (re.compile(r"what time|current time", re.IGNORECASE), "time", lambda _: {"operation": "now"}), + ( + re.compile( + r"create lightweight owner routine (?P[a-z0-9][a-z0-9_-]*)", + re.IGNORECASE, + ), + "routine_create", + lambda m: { + "name": m.group("name"), + "description": f"Owner-scope routine {m.group('name')}", + "trigger_type": "manual", + "prompt": f"Confirm that {m.group('name')} executed.", + "action_type": "lightweight", + "use_tools": False, + }, + ), + ( + re.compile( + r"create full[- ]job owner routine (?P[a-z0-9][a-z0-9_-]*)", + re.IGNORECASE, + ), + "routine_create", + lambda m: { + "name": m.group("name"), + "description": f"Owner-scope full-job routine {m.group('name')}", + "trigger_type": "manual", + "prompt": f"Complete the routine job for {m.group('name')}.", + "action_type": "full_job", + }, + ), + ( + re.compile( + r"create event routine (?P[a-z0-9][a-z0-9_-]*) " + r"channel (?P[a-z0-9_-]+) pattern (?P[a-z0-9_|-]+)", + re.IGNORECASE, + ), + "routine_create", + lambda m: { + "name": m.group("name"), + "description": f"Event routine {m.group('name')}", + "trigger_type": "event", + "event_channel": None if m.group("channel").lower() == "any" else m.group("channel"), + "event_pattern": m.group("pattern"), + "prompt": f"Acknowledge that {m.group('name')} fired.", + "action_type": "lightweight", + "use_tools": False, + "cooldown_secs": 0, + }, + ), + ( + re.compile(r"list owner routines", re.IGNORECASE), + "routine_list", + lambda _: {}, + ), ] diff --git a/tests/e2e/scenarios/test_owner_scope.py b/tests/e2e/scenarios/test_owner_scope.py new file mode 100644 index 0000000000..56f3b01ec7 --- /dev/null +++ b/tests/e2e/scenarios/test_owner_scope.py @@ -0,0 +1,226 @@ +"""Owner-scope end-to-end scenarios. + +These tests exercise the explicit owner model across: +- the web gateway chat UI +- the owner-scoped HTTP webhook channel +- routine tools / routines tab +- job creation via routine execution / jobs tab +""" + +import asyncio +import json +import uuid + +import httpx + +from helpers import SEL, AUTH_TOKEN, signed_http_webhook_headers + + +async def _send_and_get_response( + page, + message: str, + *, + expected_fragment: str, + timeout: int = 30000, +) -> str: + """Send a chat message and return the newest assistant response text.""" + chat_input = page.locator(SEL["chat_input"]) + await chat_input.wait_for(state="visible", timeout=5000) + + assistant_sel = SEL["message_assistant"] + before_count = await page.locator(assistant_sel).count() + + await chat_input.fill(message) + await chat_input.press("Enter") + + expected = before_count + 1 + await page.wait_for_function( + """({ assistantSelector, expectedCount, expectedFragment }) => { + const messages = document.querySelectorAll(assistantSelector); + if (messages.length < expectedCount) return false; + const text = (messages[messages.length - 1].innerText || '').trim().toLowerCase(); + return text.includes(expectedFragment.toLowerCase()); + }""", + arg={ + "assistantSelector": assistant_sel, + "expectedCount": expected, + "expectedFragment": expected_fragment, + }, + timeout=timeout, + ) + + return await page.locator(assistant_sel).last.inner_text() + + +async def _post_http_webhook( + http_channel_server: str, + *, + content: str, + sender_id: str, + thread_id: str, +) -> str: + """Send a signed request to the owner-scoped HTTP webhook channel.""" + payload = { + "user_id": sender_id, + "thread_id": thread_id, + "content": content, + "wait_for_response": True, + } + body = json.dumps(payload).encode("utf-8") + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{http_channel_server}/webhook", + content=body, + headers=signed_http_webhook_headers(body), + timeout=90, + ) + + assert response.status_code == 200, ( + f"HTTP webhook failed: {response.status_code} {response.text[:400]}" + ) + data = response.json() + assert data["status"] == "accepted", f"Unexpected webhook response: {data}" + assert data["response"], f"Expected synchronous response body, got: {data}" + return data["response"] + + +async def _open_tab(page, tab: str) -> None: + btn = page.locator(SEL["tab_button"].format(tab=tab)) + await btn.click() + await page.locator(SEL["tab_panel"].format(tab=tab)).wait_for( + state="visible", + timeout=5000, + ) + + +async def _wait_for_routine(base_url: str, name: str, timeout: float = 20.0) -> dict: + """Poll the routines API until the named routine exists.""" + async with httpx.AsyncClient() as client: + for _ in range(int(timeout * 2)): + response = await client.get( + f"{base_url}/api/routines", + headers={"Authorization": f"Bearer {AUTH_TOKEN}"}, + timeout=10, + ) + response.raise_for_status() + routines = response.json()["routines"] + for routine in routines: + if routine["name"] == name: + return routine + await _poll_sleep() + raise AssertionError(f"Routine '{name}' was not created within {timeout}s") + + +async def _wait_for_job(base_url: str, title: str, timeout: float = 30.0) -> dict: + """Poll the jobs API until the named job exists.""" + async with httpx.AsyncClient() as client: + for _ in range(int(timeout * 2)): + response = await client.get( + f"{base_url}/api/jobs", + headers={"Authorization": f"Bearer {AUTH_TOKEN}"}, + timeout=10, + ) + response.raise_for_status() + jobs = response.json()["jobs"] + for job in jobs: + if job["title"] == title: + return job + await _poll_sleep() + raise AssertionError(f"Job '{title}' was not created within {timeout}s") + + +async def _poll_sleep() -> None: + """Small shared backoff for API polling loops.""" + await asyncio.sleep(0.5) + + +async def test_http_channel_created_routine_is_visible_in_web_routines_tab( + page, + ironclaw_server, + http_channel_server, +): + """A routine created from the HTTP channel is visible in the web owner UI.""" + routine_name = f"owner-http-{uuid.uuid4().hex[:8]}" + + response_text = await _post_http_webhook( + http_channel_server, + content=f"create lightweight owner routine {routine_name}", + sender_id="external-sender-alpha", + thread_id="http-owner-routine-thread", + ) + assert routine_name in response_text + + await _wait_for_routine(ironclaw_server, routine_name) + + await _open_tab(page, "routines") + await page.locator(SEL["routine_row"]).filter(has_text=routine_name).first.wait_for( + state="visible", + timeout=15000, + ) + + +async def test_web_created_routine_is_listed_from_http_channel_across_senders( + page, + ironclaw_server, + http_channel_server, +): + """Routines created in web chat remain owner-global across HTTP senders/threads.""" + routine_name = f"owner-web-{uuid.uuid4().hex[:8]}" + + assistant_text = await _send_and_get_response( + page, + f"create lightweight owner routine {routine_name}", + expected_fragment=routine_name, + ) + assert routine_name in assistant_text + + await _wait_for_routine(ironclaw_server, routine_name) + + first_sender_text = await _post_http_webhook( + http_channel_server, + content="list owner routines", + sender_id="http-sender-one", + thread_id="owner-list-thread-a", + ) + second_sender_text = await _post_http_webhook( + http_channel_server, + content="list owner routines", + sender_id="http-sender-two", + thread_id="owner-list-thread-b", + ) + + assert routine_name in first_sender_text, first_sender_text + assert routine_name in second_sender_text, second_sender_text + + +async def test_http_created_full_job_routine_can_be_run_from_web_and_shows_in_jobs( + page, + ironclaw_server, + http_channel_server, +): + """A full-job routine created via HTTP can be run from the web UI and create a job.""" + routine_name = f"owner-job-{uuid.uuid4().hex[:8]}" + + response_text = await _post_http_webhook( + http_channel_server, + content=f"create full-job owner routine {routine_name}", + sender_id="http-job-sender", + thread_id="owner-job-thread", + ) + assert routine_name in response_text + + await _wait_for_routine(ironclaw_server, routine_name) + + await _open_tab(page, "routines") + routine_row = page.locator(SEL["routine_row"]).filter(has_text=routine_name).first + await routine_row.wait_for(state="visible", timeout=15000) + await routine_row.locator('button[data-action="trigger-routine"]').click() + + await _wait_for_job(ironclaw_server, routine_name, timeout=45.0) + + await _open_tab(page, "jobs") + await page.locator(SEL["job_row"]).filter(has_text=routine_name).first.wait_for( + state="visible", + timeout=20000, + ) diff --git a/tests/e2e/scenarios/test_routine_event_batch.py b/tests/e2e/scenarios/test_routine_event_batch.py index d8c59e6d94..7da78a15c1 100644 --- a/tests/e2e/scenarios/test_routine_event_batch.py +++ b/tests/e2e/scenarios/test_routine_event_batch.py @@ -1,534 +1,317 @@ -""" -E2E tests for event-triggered routines with batch loading. - -These tests verify that the N+1 query fix correctly: -1. Fires event-triggered routines on matching messages -2. Enforces concurrent limits via batch-loaded counts -3. Maintains performance with multiple simultaneous triggers -4. Works correctly through the full UI and agent loop - -Playwright-based UI tests + SSE verification. -""" +"""E2E tests for event-triggered routines over the HTTP channel.""" import asyncio import json -import pytest -from datetime import datetime, timedelta -from typing import List, Dict, Any - -from playwright.async_api import async_playwright, Page, Browser, BrowserContext - - -@pytest.fixture -async def browser_and_context(): - """Create a Playwright browser and context for testing.""" - async with async_playwright() as p: - browser = await p.chromium.launch(headless=True) - context = await browser.new_context() - yield browser, context - await context.close() - await browser.close() - - -class EventTriggerHelper: - """Helper methods for event trigger testing.""" - - def __init__(self, page: Page): - self.page = page - - async def navigate_to_routines(self): - """Navigate to the routines page.""" - await self.page.goto("http://localhost:8000/routines") - await self.page.wait_for_load_state("networkidle") - - async def create_event_routine( - self, - name: str, - trigger_regex: str, - channel: str = "slack", - max_concurrent: int = 1, - ) -> str: - """ - Create an event-triggered routine via UI. - Returns the routine ID. - """ - await self.navigate_to_routines() - - # Click "New Routine" button - await self.page.click('button:has-text("New Routine")') - await self.page.wait_for_selector('input[name="routine_name"]') - - # Fill routine details - await self.page.fill('input[name="routine_name"]', name) - await self.page.fill( - 'textarea[name="routine_description"]', - f"Test routine: {name}", - ) - - # Select "Event Trigger" type - await self.page.click('label:has-text("Event Trigger")') - await self.page.wait_for_selector('input[name="trigger_regex"]') - - # Fill trigger details - await self.page.fill('input[name="trigger_regex"]', trigger_regex) - await self.page.select_option('select[name="trigger_channel"]', channel) - - # Set guardrails - await self.page.fill('input[name="max_concurrent"]', str(max_concurrent)) - - # Select lightweight action - await self.page.click('label:has-text("Lightweight")') - await self.page.fill( - 'textarea[name="lightweight_prompt"]', - "Acknowledge the message and confirm trigger worked.", - ) - - # Save routine - await self.page.click('button:has-text("Save Routine")') - await self.page.wait_for_selector('text=Routine created successfully') - - # Extract routine ID from success message or URL - routine_id = await self.page.locator('data-testid=routine-id').text_content() - return routine_id.strip() if routine_id else None - - async def create_multiple_routines( - self, base_name: str, count: int, trigger_regex: str = None - ) -> List[str]: - """Create multiple event-triggered routines.""" - routine_ids = [] - for i in range(count): - name = f"{base_name}_{i}" - regex = trigger_regex or f"({i}|{base_name})" - routine_id = await self.create_event_routine(name, regex) - routine_ids.append(routine_id) - await asyncio.sleep(0.1) # Small delay between creations - return routine_ids - - async def send_chat_message(self, message: str) -> List[str]: - """ - Send a chat message and return SSE events received. - Captures all routine firing events. - """ - await self.page.goto("http://localhost:8000/chat") - await self.page.wait_for_selector('input[placeholder*="message"]', timeout=5000) - - # Collect SSE events - sse_events = [] - - async def capture_sse(response): - """Intercept SSE events.""" - if "event-stream" in response.headers.get("content-type", ""): - text = await response.text() - for line in text.split("\n"): - if line.startswith("data:"): - try: - event = json.loads(line[5:]) - sse_events.append(event) - except json.JSONDecodeError: - pass - - self.page.on("response", capture_sse) - - # Send message - await self.page.fill('input[placeholder*="message"]', message) - await self.page.press('input[placeholder*="message"]', "Enter") - - # Wait for response - await self.page.wait_for_selector('text=Message processed', timeout=10000) - await asyncio.sleep(0.5) # Allow time for SSE events - - self.page.remove_listener("response", capture_sse) - return sse_events - - async def get_routine_execution_log(self, routine_id: str) -> List[Dict]: - """Get execution log entries for a routine.""" - await self.page.goto(f"http://localhost:8000/routines/{routine_id}/executions") - await self.page.wait_for_load_state("networkidle") - - # Extract log entries from table - rows = await self.page.locator("tbody tr").all() - executions = [] - - for row in rows: - cells = await row.locator("td").all() - if len(cells) >= 3: - execution = { - "timestamp": await cells[0].text_content(), - "status": await cells[1].text_content(), - "details": await cells[2].text_content(), - } - executions.append(execution) - - return executions - - async def check_database_queries_in_logs( - self, max_queries_expected: int = 1 - ) -> int: - """Check debug logs for database query count.""" - await self.page.goto("http://localhost:8000/debug/logs?filter=database") - await self.page.wait_for_load_state("networkidle") - - # Count batch queries - log_lines = await self.page.locator("tr:has-text('batch')").all() - batch_count = len(log_lines) - - # Count individual COUNT queries (should be 0 after fix) - count_queries = await self.page.locator("tr:has-text('COUNT')").all() - count_query_count = len(count_queries) - - return batch_count, count_query_count - - -# ============================================================================= -# Tests -# ============================================================================= - - -@pytest.mark.asyncio -async def test_create_event_trigger_routine(browser_and_context): - """Test creating an event-triggered routine via UI.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - routine_id = await helper.create_event_routine( - name="Test Trigger", - trigger_regex="test|demo", - channel="slack", - max_concurrent=1, - ) - - assert routine_id is not None, "Routine ID should be returned" - assert len(routine_id) > 0, "Routine ID should not be empty" - - finally: - await page.close() - - -@pytest.mark.asyncio -async def test_event_trigger_fires_on_matching_message(browser_and_context): - """Test that event-triggered routine fires when message matches.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine - routine_id = await helper.create_event_routine( - name="Alert Handler", - trigger_regex="urgent|critical|alert", - channel="slack", - ) +import uuid - # Send matching message - sse_events = await helper.send_chat_message("URGENT: Server down!") +import httpx +import pytest - # Verify routine fired (look for event in SSE stream) - routine_fired = any( - event.get("type") == "routine_fired" and event.get("routine_id") == routine_id - for event in sse_events +from helpers import AUTH_TOKEN, SEL, signed_http_webhook_headers + + +async def _send_chat_message(page, message: str) -> None: + """Send a chat message and wait for the assistant turn to appear.""" + chat_input = page.locator(SEL["chat_input"]) + await chat_input.wait_for(state="visible", timeout=5000) + assistant_messages = page.locator(SEL["message_assistant"]) + before_count = await assistant_messages.count() + + await chat_input.fill(message) + await chat_input.press("Enter") + + await page.wait_for_function( + """({ selector, expectedCount }) => { + return document.querySelectorAll(selector).length >= expectedCount; + }""", + arg={ + "selector": SEL["message_assistant"], + "expectedCount": before_count + 1, + }, + timeout=30000, + ) + + +async def _create_event_routine( + page, + base_url: str, + *, + name: str, + pattern: str, + channel: str = "http", +) -> dict: + """Create an event routine through chat and return its API record.""" + await _send_chat_message( + page, + f"create event routine {name} channel {channel} pattern {pattern}", + ) + return await _wait_for_routine(base_url, name) + + +async def _post_http_message( + http_channel_server: str, + *, + content: str, + sender_id: str | None = None, + thread_id: str | None = None, +) -> dict: + """Send a signed HTTP-channel message and return the JSON body.""" + payload = { + "user_id": sender_id or f"sender-{uuid.uuid4().hex[:8]}", + "thread_id": thread_id or f"thread-{uuid.uuid4().hex[:8]}", + "content": content, + "wait_for_response": True, + } + body = json.dumps(payload).encode("utf-8") + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{http_channel_server}/webhook", + content=body, + headers=signed_http_webhook_headers(body), + timeout=90, ) - assert routine_fired, "Routine should fire on matching message" - - # Check execution log - executions = await helper.get_routine_execution_log(routine_id) - assert len(executions) > 0, "Execution should be logged" - assert "success" in executions[0]["status"].lower() - - finally: - await page.close() - -@pytest.mark.asyncio -async def test_event_trigger_skips_non_matching_message(browser_and_context): - """Test that event-triggered routine skips when message doesn't match.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine - routine_id = await helper.create_event_routine( - name="Alert Handler", - trigger_regex="urgent|critical|alert", - channel="slack", - ) + assert response.status_code == 200, ( + f"HTTP webhook failed: {response.status_code} {response.text[:400]}" + ) + return response.json() - # Send non-matching message - sse_events = await helper.send_chat_message("Hello, how are you?") - # Verify routine did NOT fire - routine_fired = any( - event.get("type") == "routine_fired" and event.get("routine_id") == routine_id - for event in sse_events +async def _wait_for_routine(base_url: str, name: str, timeout: float = 20.0) -> dict: + """Poll the routines API until the named routine exists.""" + async with httpx.AsyncClient() as client: + for _ in range(int(timeout * 2)): + response = await client.get( + f"{base_url}/api/routines", + headers={"Authorization": f"Bearer {AUTH_TOKEN}"}, + timeout=10, + ) + response.raise_for_status() + for routine in response.json()["routines"]: + if routine["name"] == name: + return routine + await asyncio.sleep(0.5) + raise AssertionError(f"Routine '{name}' was not created within {timeout}s") + + +async def _get_routine_runs(base_url: str, routine_id: str) -> list[dict]: + """Fetch recent routine runs from the web API.""" + async with httpx.AsyncClient() as client: + response = await client.get( + f"{base_url}/api/routines/{routine_id}/runs", + headers={"Authorization": f"Bearer {AUTH_TOKEN}"}, + timeout=10, ) - assert not routine_fired, "Routine should not fire on non-matching message" - - finally: - await page.close() + response.raise_for_status() + return response.json()["runs"] + + +async def _wait_for_run_count( + base_url: str, + routine_id: str, + *, + expected_at_least: int, + timeout: float = 20.0, +) -> list[dict]: + """Poll until the routine has at least the expected run count.""" + for _ in range(int(timeout * 2)): + runs = await _get_routine_runs(base_url, routine_id) + if len(runs) >= expected_at_least: + return runs + await asyncio.sleep(0.5) + raise AssertionError( + f"Routine '{routine_id}' did not reach {expected_at_least} runs within {timeout}s" + ) + + +async def _wait_for_completed_run( + base_url: str, + routine_id: str, + *, + timeout: float = 30.0, +) -> dict: + """Poll until the newest run is no longer marked running.""" + for _ in range(int(timeout * 2)): + runs = await _get_routine_runs(base_url, routine_id) + if runs and runs[0]["status"].lower() != "running": + return runs[0] + await asyncio.sleep(0.5) + raise AssertionError(f"Routine '{routine_id}' did not complete within {timeout}s") @pytest.mark.asyncio -async def test_multiple_routines_fire_on_matching_message(browser_and_context): - """Test that multiple event-triggered routines fire on same message.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create 3 overlapping routines - routine_ids = await helper.create_multiple_routines( - base_name="Handler", count=3, trigger_regex="alert|warning|error" - ) +async def test_create_event_trigger_routine(page, ironclaw_server): + """Event routines can be created through the supported chat flow.""" + name = f"evt-{uuid.uuid4().hex[:8]}" + routine = await _create_event_routine( + page, + ironclaw_server, + name=name, + pattern="test|demo", + ) - # Send matching message - sse_events = await helper.send_chat_message("ERROR: Database connection failed") - - # Verify all 3 routines fired - fired_count = sum( - 1 - for event in sse_events - if event.get("type") == "routine_fired" and event.get("routine_id") in routine_ids - ) - - assert ( - fired_count >= 3 - ), f"Expected all 3 routines to fire, got {fired_count}" - - finally: - await page.close() + assert routine["id"] + assert routine["trigger_type"] == "event" + assert "test|demo" in routine["trigger_summary"] @pytest.mark.asyncio -async def test_concurrent_limit_prevents_additional_fires(browser_and_context): - """Test that concurrent limit is enforced via batch counts.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine with max_concurrent=1 - routine_id = await helper.create_event_routine( - name="Limited Handler", - trigger_regex="process|task", - max_concurrent=1, - ) - - # Trigger first message - await helper.send_chat_message("Process message 1") - await asyncio.sleep(1) - - # Check first execution logged - executions_1 = await helper.get_routine_execution_log(routine_id) - assert len(executions_1) >= 1 - - # Trigger second message while first is still running - sse_events = await helper.send_chat_message("Process message 2") - - # Second routine should be skipped (concurrent limit) - routine_skipped = any( - event.get("type") == "routine_skipped" - and event.get("reason") == "max_concurrent_reached" - and event.get("routine_id") == routine_id - for event in sse_events - ) - assert routine_skipped, "Routine should be skipped when concurrent limit reached" - - finally: - await page.close() +async def test_event_trigger_fires_on_matching_message( + page, + ironclaw_server, + http_channel_server, +): + """Matching HTTP-channel messages create routine runs.""" + name = f"evt-{uuid.uuid4().hex[:8]}" + routine = await _create_event_routine( + page, + ironclaw_server, + name=name, + pattern="urgent|critical|alert", + ) + + response = await _post_http_message( + http_channel_server, + content="urgent: server down", + ) + assert response["status"] == "accepted" + + await _wait_for_run_count( + ironclaw_server, + routine["id"], + expected_at_least=1, + ) + completed_run = await _wait_for_completed_run(ironclaw_server, routine["id"]) + + assert completed_run["status"].lower() == "attention" + assert completed_run["trigger_type"] == "event" @pytest.mark.asyncio -async def test_rapid_messages_with_multiple_triggers_efficiency(browser_and_context): - """Test efficiency of batch loading with multiple rapid messages.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create 5 overlapping routines - routine_ids = await helper.create_multiple_routines( - base_name="Rapid", count=5, trigger_regex="test|demo|check" - ) - - # Send 10 matching messages rapidly - for i in range(10): - message = f"test message {i}" - await helper.send_chat_message(message) - await asyncio.sleep(0.1) - - # Check database logs for query efficiency - batch_count, count_query_count = await helper.check_database_queries_in_logs() - - # After fix: should have ~10 batch queries (1 per message) - # Before fix: would have ~50 individual COUNT queries (5 routines × 10 messages) - assert ( - count_query_count == 0 - ), f"Should have 0 individual COUNT queries after fix, got {count_query_count}" - assert ( - batch_count <= 15 - ), f"Should have <=15 batch queries for 10 messages, got {batch_count}" - - finally: - await page.close() +async def test_event_trigger_skips_non_matching_message( + page, + ironclaw_server, + http_channel_server, +): + """Non-matching messages do not create routine runs.""" + name = f"evt-{uuid.uuid4().hex[:8]}" + routine = await _create_event_routine( + page, + ironclaw_server, + name=name, + pattern="urgent|critical|alert", + ) + + await _post_http_message( + http_channel_server, + content="hello there", + ) + await asyncio.sleep(2) + + assert await _get_routine_runs(ironclaw_server, routine["id"]) == [] @pytest.mark.asyncio -async def test_channel_filter_applied_correctly(browser_and_context): - """Test that channel filter prevents non-matching messages.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine for Slack channel - slack_routine_id = await helper.create_event_routine( - name="Slack Handler", - trigger_regex="alert", - channel="slack", +async def test_multiple_routines_fire_on_matching_message( + page, + ironclaw_server, + http_channel_server, +): + """A single matching message can fire multiple event routines.""" + routines = [] + for _ in range(3): + name = f"evt-{uuid.uuid4().hex[:8]}" + routines.append( + await _create_event_routine( + page, + ironclaw_server, + name=name, + pattern="error|warning|alert", + ) ) - # Simulate message from Telegram channel - # (Note: In real UI, would need to change channel context) - page.goto( - "http://localhost:8000/chat?channel=telegram" - ) # Switch channel - await helper.send_chat_message("alert: something urgent") - - # Routine should not fire (different channel) - executions = await helper.get_routine_execution_log(slack_routine_id) - - # Check if any recent execution (last 5 min) exists - recent = [ - e - for e in executions - if (datetime.now() - datetime.fromisoformat(e["timestamp"])).total_seconds() - < 300 - ] - assert ( - len(recent) == 0 - ), "Routine should not fire for different channel" + await _post_http_message( + http_channel_server, + content="error: database connection failed", + ) - finally: - await page.close() - - -@pytest.mark.asyncio -async def test_batch_query_failure_handling(browser_and_context): - """Test graceful handling of batch query failures.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine - routine_id = await helper.create_event_routine( - name="Error Handler", - trigger_regex="test", + for routine in routines: + await _wait_for_run_count( + ironclaw_server, + routine["id"], + expected_at_least=1, ) - - # Simulate database error in logs (if possible with test hooks) - # For now, just verify error handling doesn't crash UI - await helper.send_chat_message("test message") - - # Check that UI remains responsive - assert await page.locator("text=Message processed").is_visible() - - finally: - await page.close() + completed_run = await _wait_for_completed_run(ironclaw_server, routine["id"]) + assert completed_run["status"].lower() == "attention" @pytest.mark.asyncio -async def test_routine_execution_history_display(browser_and_context): - """Test that execution history correctly displays routine firings.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create routine - routine_id = await helper.create_event_routine( - name="History Test", - trigger_regex="test", - ) - - # Trigger routine 3 times - for i in range(3): - await helper.send_chat_message(f"test message {i}") - await asyncio.sleep(0.2) - - # Check execution log - executions = await helper.get_routine_execution_log(routine_id) - assert len(executions) >= 3, "Should have at least 3 executions logged" - - # Verify all are recent (within last 5 minutes) - for execution in executions[:3]: - timestamp = datetime.fromisoformat(execution["timestamp"]) - age = datetime.now() - timestamp - assert age < timedelta(minutes=5), "Execution should be recent" - - finally: - await page.close() +async def test_channel_filter_applied_correctly( + page, + ironclaw_server, + http_channel_server, +): + """Channel filters prevent HTTP messages from firing non-HTTP routines.""" + http_routine = await _create_event_routine( + page, + ironclaw_server, + name=f"evt-{uuid.uuid4().hex[:8]}", + pattern="alert", + channel="http", + ) + telegram_routine = await _create_event_routine( + page, + ironclaw_server, + name=f"evt-{uuid.uuid4().hex[:8]}", + pattern="alert", + channel="telegram", + ) + + await _post_http_message( + http_channel_server, + content="alert from webhook", + ) + + await _wait_for_run_count( + ironclaw_server, + http_routine["id"], + expected_at_least=1, + ) + http_run = await _wait_for_completed_run(ironclaw_server, http_routine["id"]) + await asyncio.sleep(2) + telegram_runs = await _get_routine_runs(ironclaw_server, telegram_routine["id"]) + + assert http_run["status"].lower() == "attention" + assert telegram_runs == [] @pytest.mark.asyncio -async def test_concurrent_batch_loads_independent(browser_and_context): - """Test that concurrent messages each get independent batch queries.""" - browser, context = browser_and_context - page = await context.new_page() - helper = EventTriggerHelper(page) - - try: - # Create 5 routines matching different patterns - r1_id = await helper.create_event_routine( - name="Pattern A", trigger_regex="alpha|alpha_only" - ) - r2_id = await helper.create_event_routine( - name="Pattern B", trigger_regex="beta|beta_only" - ) - r3_id = await helper.create_event_routine( - name="Pattern AB", trigger_regex="alpha|beta|common" - ) - - # Send overlapping messages - # Message 1: matches r1, r3 - sse1 = await helper.send_chat_message("alpha common") - await asyncio.sleep(0.1) - - # Message 2: matches r2, r3 - sse2 = await helper.send_chat_message("beta common") - await asyncio.sleep(0.1) - - # Verify correct routines fired - r1_fired_msg1 = any( - e.get("routine_id") == r1_id for e in sse1 if e.get("type") == "routine_fired" - ) - r2_fired_msg2 = any( - e.get("routine_id") == r2_id for e in sse2 if e.get("type") == "routine_fired" - ) - r3_fired_both = ( - any( - e.get("routine_id") == r3_id for e in sse1 if e.get("type") == "routine_fired" - ) - and any( - e.get("routine_id") == r3_id for e in sse2 if e.get("type") == "routine_fired" - ) - ) - - assert r1_fired_msg1, "Routine 1 should fire on message 1" - assert r2_fired_msg2, "Routine 2 should fire on message 2" - assert r3_fired_both, "Routine 3 should fire on both messages" - - finally: - await page.close() - - -# ============================================================================= -# Integration with existing test patterns -# ============================================================================= - - -if __name__ == "__main__": - # Run tests with: pytest tests/e2e/scenarios/test_routine_event_batch.py -v - pytest.main([__file__, "-v", "-s"]) +async def test_routine_execution_history_is_available( + page, + ironclaw_server, + http_channel_server, +): + """Routine run history is exposed by the routines runs API.""" + routine = await _create_event_routine( + page, + ironclaw_server, + name=f"evt-{uuid.uuid4().hex[:8]}", + pattern="history", + ) + + await _post_http_message( + http_channel_server, + content="history event", + ) + + await _wait_for_run_count( + ironclaw_server, + routine["id"], + expected_at_least=1, + ) + completed_run = await _wait_for_completed_run(ironclaw_server, routine["id"]) + + assert completed_run["id"] + assert completed_run["started_at"] + assert completed_run["status"].lower() == "attention" diff --git a/tests/e2e/scenarios/test_webhook.py b/tests/e2e/scenarios/test_webhook.py index c0227c97ec..e6f9b26e70 100644 --- a/tests/e2e/scenarios/test_webhook.py +++ b/tests/e2e/scenarios/test_webhook.py @@ -7,7 +7,7 @@ import httpx import pytest -from helpers import AUTH_TOKEN +from helpers import HTTP_WEBHOOK_SECRET def compute_signature(secret: str, body: bytes) -> str: @@ -16,325 +16,188 @@ def compute_signature(secret: str, body: bytes) -> str: return f"sha256={mac.hexdigest()}" -@pytest.mark.asyncio -async def test_webhook_requires_http_webhook_secret_configured(ironclaw_server): - """ - Webhook endpoint rejects requests when HTTP_WEBHOOK_SECRET is not configured. - This tests the fail-closed security posture. - """ - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - async with httpx.AsyncClient() as client: - # When no webhook secret is configured on the server, all requests fail - r = await client.post( - f"{ironclaw_server}/webhook", - json={"content": "test message"}, - headers=headers, - ) - # Server should reject with 503 Service Unavailable (fail closed) - assert r.status_code in (401, 503) - - -@pytest.mark.asyncio -async def test_webhook_hmac_signature_valid(ironclaw_server_with_webhook_secret): - """Valid X-Hub-Signature-256 HMAC signature is accepted.""" - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello from webhook"} +async def _post_webhook( + base_url: str, + body_data: dict, + *, + signature: str | None = None, + content_type: str = "application/json", +) -> httpx.Response: + """Send a raw webhook request with optional signature.""" body_bytes = json.dumps(body_data).encode() - signature = compute_signature(secret, body_bytes) + headers = {"Content-Type": content_type} + if signature is not None: + headers["X-Hub-Signature-256"] = signature async with httpx.AsyncClient() as client: - r = await client.post( + return await client.post( f"{base_url}/webhook", content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, + headers=headers, ) - assert r.status_code == 200, f"Expected 200, got {r.status_code}: {r.text}" - resp = r.json() - assert resp["status"] == "ok" @pytest.mark.asyncio -async def test_webhook_invalid_hmac_signature_rejected( - ironclaw_server_with_webhook_secret, +async def test_webhook_requires_http_webhook_secret_configured( + http_channel_server_without_secret, ): - """Invalid X-Hub-Signature-256 signature is rejected with 401.""" - base_url = ironclaw_server_with_webhook_secret["url"] + """Webhook fails closed when no secret is configured.""" + response = await _post_webhook( + http_channel_server_without_secret, + {"content": "test message"}, + ) - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() - invalid_signature = "sha256=0000000000000000000000000000000000000000000000000000000000000000" - - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": invalid_signature, - }, - ) - assert r.status_code == 401, f"Expected 401, got {r.status_code}" - resp = r.json() - assert resp["status"] == "error" - assert "Invalid webhook signature" in resp.get("response", "") + assert response.status_code == 503 + data = response.json() + assert data["status"] == "error" + assert "Webhook authentication not configured" in data.get("response", "") @pytest.mark.asyncio -async def test_webhook_wrong_secret_rejected(ironclaw_server_with_webhook_secret): - """Signature computed with wrong secret is rejected.""" - base_url = ironclaw_server_with_webhook_secret["url"] +async def test_webhook_hmac_signature_valid(http_channel_server): + """Valid X-Hub-Signature-256 HMAC signature is accepted.""" + body = {"content": "hello from webhook"} + signature = compute_signature(HTTP_WEBHOOK_SECRET, json.dumps(body).encode()) - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() - # Compute signature with wrong secret - wrong_signature = compute_signature("wrong-secret", body_bytes) + response = await _post_webhook(http_channel_server, body, signature=signature) - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": wrong_signature, - }, - ) - assert r.status_code == 401 - resp = r.json() - assert resp["status"] == "error" + assert response.status_code == 200, ( + f"Expected 200, got {response.status_code}: {response.text}" + ) + data = response.json() + assert data["status"] == "accepted" @pytest.mark.asyncio -async def test_webhook_malformed_signature_rejected( - ironclaw_server_with_webhook_secret, -): - """Malformed X-Hub-Signature-256 header is rejected.""" - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() +async def test_webhook_invalid_hmac_signature_rejected(http_channel_server): + """Invalid X-Hub-Signature-256 signature is rejected with 401.""" + response = await _post_webhook( + http_channel_server, + {"content": "hello"}, + signature="sha256=0000000000000000000000000000000000000000000000000000000000000000", + ) - async with httpx.AsyncClient() as client: - # Missing sha256= prefix - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": "deadbeef", - }, - ) - assert r.status_code == 401 + assert response.status_code == 401 + data = response.json() + assert data["status"] == "error" + assert "Invalid webhook signature" in data.get("response", "") @pytest.mark.asyncio -async def test_webhook_missing_signature_header_rejected( - ironclaw_server_with_webhook_secret, -): - """Missing X-Hub-Signature-256 header is rejected when no body secret provided.""" - base_url = ironclaw_server_with_webhook_secret["url"] +async def test_webhook_wrong_secret_rejected(http_channel_server): + """Signature computed with wrong secret is rejected.""" + body = {"content": "hello"} + signature = compute_signature("wrong-secret", json.dumps(body).encode()) - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() + response = await _post_webhook(http_channel_server, body, signature=signature) - async with httpx.AsyncClient() as client: - # No X-Hub-Signature-256 header and no body secret - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - }, - ) - assert r.status_code == 401 - resp = r.json() - assert "Webhook authentication required" in resp.get("response", "") - assert "X-Hub-Signature-256" in resp.get("response", "") + assert response.status_code == 401 + assert response.json()["status"] == "error" @pytest.mark.asyncio -async def test_webhook_deprecated_body_secret_still_works( - ironclaw_server_with_webhook_secret, -): - """ - Deprecated: body 'secret' field still works for backward compatibility. - This test ensures we don't break existing clients during the migration period. - """ - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - # Old-style request with secret in body - body_data = {"content": "hello", "secret": secret} - body_bytes = json.dumps(body_data).encode() +async def test_webhook_missing_signature_header_rejected(http_channel_server): + """Missing X-Hub-Signature-256 header is rejected when no body secret is provided.""" + response = await _post_webhook(http_channel_server, {"content": "hello"}) - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - }, - ) - # Should succeed (backward compatibility) - assert r.status_code == 200, f"Expected 200, got {r.status_code}: {r.text}" - resp = r.json() - assert resp["status"] == "ok" + assert response.status_code == 401 + data = response.json() + assert "Webhook authentication required" in data.get("response", "") + assert "X-Hub-Signature-256" in data.get("response", "") @pytest.mark.asyncio -async def test_webhook_header_takes_precedence_over_body_secret( - ironclaw_server_with_webhook_secret, -): - """ - When both X-Hub-Signature-256 header and body secret are provided, - header takes precedence. - """ - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello", "secret": "wrong-secret-in-body"} - body_bytes = json.dumps(body_data).encode() - # Compute signature with correct secret - signature = compute_signature(secret, body_bytes) +async def test_webhook_deprecated_body_secret_still_works(http_channel_server): + """Deprecated body secret support still accepts old clients.""" + response = await _post_webhook( + http_channel_server, + {"content": "hello", "secret": HTTP_WEBHOOK_SECRET}, + ) - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, - ) - # Should succeed because header signature is valid (takes precedence) - assert r.status_code == 200 - resp = r.json() - assert resp["status"] == "ok" + assert response.status_code == 200, ( + f"Expected 200, got {response.status_code}: {response.text}" + ) + assert response.json()["status"] == "accepted" @pytest.mark.asyncio -async def test_webhook_case_insensitive_header_lookup( - ironclaw_server_with_webhook_secret, -): - """HTTP headers are case-insensitive. Test with different cases.""" - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] +async def test_webhook_header_takes_precedence_over_body_secret(http_channel_server): + """Header signature wins when both header and body secret are provided.""" + body = {"content": "hello", "secret": "wrong-secret-in-body"} + signature = compute_signature(HTTP_WEBHOOK_SECRET, json.dumps(body).encode()) - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() - signature = compute_signature(secret, body_bytes) + response = await _post_webhook(http_channel_server, body, signature=signature) + + assert response.status_code == 200 + assert response.json()["status"] == "accepted" + + +@pytest.mark.asyncio +async def test_webhook_case_insensitive_header_lookup(http_channel_server): + """HTTP headers are treated case-insensitively.""" + body = {"content": "hello"} + body_bytes = json.dumps(body).encode() + signature = compute_signature(HTTP_WEBHOOK_SECRET, body_bytes) async with httpx.AsyncClient() as client: - # Try with lowercase - r = await client.post( - f"{base_url}/webhook", + response = await client.post( + f"{http_channel_server}/webhook", content=body_bytes, headers={ - **headers, "Content-Type": "application/json", "x-hub-signature-256": signature, }, ) - assert r.status_code == 200 + + assert response.status_code == 200 @pytest.mark.asyncio -async def test_webhook_wrong_content_type_rejected( - ironclaw_server_with_webhook_secret, -): +async def test_webhook_wrong_content_type_rejected(http_channel_server): """Webhook only accepts application/json Content-Type.""" - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] + body = {"content": "hello"} + signature = compute_signature(HTTP_WEBHOOK_SECRET, json.dumps(body).encode()) - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - body_data = {"content": "hello"} - body_bytes = json.dumps(body_data).encode() - signature = compute_signature(secret, body_bytes) + response = await _post_webhook( + http_channel_server, + body, + signature=signature, + content_type="text/plain", + ) - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "text/plain", - "X-Hub-Signature-256": signature, - }, - ) - assert r.status_code == 415 # Unsupported Media Type - resp = r.json() - assert "application/json" in resp.get("response", "") + assert response.status_code == 415 + assert "application/json" in response.json().get("response", "") @pytest.mark.asyncio -async def test_webhook_invalid_json_rejected(ironclaw_server_with_webhook_secret): +async def test_webhook_invalid_json_rejected(http_channel_server): """Invalid JSON in body is rejected.""" - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} body_bytes = b"not valid json" - signature = compute_signature(secret, body_bytes) + signature = compute_signature(HTTP_WEBHOOK_SECRET, body_bytes) async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", + response = await client.post( + f"{http_channel_server}/webhook", content=body_bytes, headers={ - **headers, "Content-Type": "application/json", "X-Hub-Signature-256": signature, }, ) - assert r.status_code == 401 or r.status_code == 400 + assert response.status_code in (400, 401) -@pytest.mark.asyncio -async def test_webhook_message_queued_for_processing( - ironclaw_server_with_webhook_secret, -): - """Message via webhook is queued and can be retrieved.""" - secret = ironclaw_server_with_webhook_secret["secret"] - base_url = ironclaw_server_with_webhook_secret["url"] - - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} - test_message = "webhook test message 12345" - body_data = {"content": test_message} - body_bytes = json.dumps(body_data).encode() - signature = compute_signature(secret, body_bytes) - async with httpx.AsyncClient() as client: - r = await client.post( - f"{base_url}/webhook", - content=body_bytes, - headers={ - **headers, - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, - ) - assert r.status_code == 200 - resp = r.json() - assert resp["status"] == "ok" - # Message ID should be present - assert "message_id" in resp - assert resp["message_id"] != "00000000-0000-0000-0000-000000000000" +@pytest.mark.asyncio +async def test_webhook_message_queued_for_processing(http_channel_server): + """Accepted webhook requests return a real message id.""" + body = {"content": "webhook test message 12345"} + signature = compute_signature(HTTP_WEBHOOK_SECRET, json.dumps(body).encode()) + + response = await _post_webhook(http_channel_server, body, signature=signature) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "accepted" + assert "message_id" in data + assert data["message_id"] != "00000000-0000-0000-0000-000000000000" diff --git a/tests/e2e_builtin_tool_coverage.rs b/tests/e2e_builtin_tool_coverage.rs index 4da65c23cc..2a97a0d503 100644 --- a/tests/e2e_builtin_tool_coverage.rs +++ b/tests/e2e_builtin_tool_coverage.rs @@ -155,7 +155,7 @@ mod tests { } assert_eq!(routine.notify.channel.as_deref(), Some("telegram")); - assert_eq!(routine.notify.user, "ops-team"); + assert_eq!(routine.notify.user.as_deref(), Some("ops-team")); assert_eq!(routine.guardrails.cooldown.as_secs(), 600); rig.shutdown(); diff --git a/tests/e2e_routine_heartbeat.rs b/tests/e2e_routine_heartbeat.rs index 6d6deb8bec..48fb1ef462 100644 --- a/tests/e2e_routine_heartbeat.rs +++ b/tests/e2e_routine_heartbeat.rs @@ -48,6 +48,19 @@ mod tests { Arc::new(Workspace::new_with_db("default", db.clone())) } + fn make_message( + channel: &str, + user_id: &str, + owner_id: &str, + sender_id: &str, + content: &str, + ) -> IncomingMessage { + IncomingMessage::new(channel, user_id, content) + .with_owner_id(owner_id) + .with_sender_id(sender_id) + .with_metadata(serde_json::json!({})) + } + /// Helper to insert a routine directly into the database. fn make_routine(name: &str, trigger: Trigger, prompt: &str) -> Routine { Routine { @@ -218,7 +231,13 @@ mod tests { engine.refresh_event_cache().await; // Positive match: message containing "deploy to production". - let matching_msg = IncomingMessage::new("test", "default", "deploy to production now"); + let matching_msg = make_message( + "test", + "default", + "default", + "default", + "deploy to production now", + ); let fired = engine.check_event_triggers(&matching_msg).await; assert!( fired >= 1, @@ -229,12 +248,114 @@ mod tests { tokio::time::sleep(Duration::from_millis(500)).await; // Negative match: message that doesn't match. - let non_matching_msg = - IncomingMessage::new("test", "default", "check the staging environment"); + let non_matching_msg = make_message( + "test", + "default", + "default", + "default", + "check the staging environment", + ); let fired_neg = engine.check_event_triggers(&non_matching_msg).await; assert_eq!(fired_neg, 0, "Expected 0 routines fired on non-match"); } + #[tokio::test] + async fn event_trigger_respects_message_user_scope() { + let (db, _tmp) = create_test_db().await; + let ws = create_workspace(&db); + + let trace = LlmTrace::single_turn( + "test-event-user-scope", + "deploy", + vec![TraceStep { + request_hint: None, + response: TraceResponse::Text { + content: "Owner event handled".to_string(), + input_tokens: 50, + output_tokens: 8, + }, + expected_tool_results: vec![], + }], + ); + let llm = Arc::new(TraceLlm::from_trace(trace)); + let (notify_tx, _notify_rx) = tokio::sync::mpsc::channel(16); + + let tools = Arc::new(ToolRegistry::new()); + let safety = Arc::new(SafetyLayer::new(&SafetyConfig { + max_output_length: 100_000, + injection_check_enabled: true, + })); + + let engine = Arc::new(RoutineEngine::new( + RoutineConfig::default(), + db.clone(), + llm, + ws, + notify_tx, + None, + tools, + safety, + )); + + let routine = make_routine( + "owner-deploy-watcher", + Trigger::Event { + channel: None, + pattern: "deploy.*production".to_string(), + }, + "Report on deployment.", + ); + db.create_routine(&routine).await.expect("create_routine"); + engine.refresh_event_cache().await; + + let guest_msg = make_message( + "telegram", + "guest", + "default", + "guest-sender", + "deploy to production now", + ); + let guest_fired = engine.check_event_triggers(&guest_msg).await; + assert_eq!( + guest_fired, 0, + "Guest scope must not fire owner event routines" + ); + tokio::time::sleep(Duration::from_millis(200)).await; + + let guest_runs = db + .list_routine_runs(routine.id, 10) + .await + .expect("list_routine_runs after guest message"); + assert!( + guest_runs.is_empty(), + "Guest message should not create routine runs" + ); + + let owner_msg = make_message( + "telegram", + "default", + "default", + "owner-sender", + "deploy to production now", + ); + let owner_fired = engine.check_event_triggers(&owner_msg).await; + assert!( + owner_fired >= 1, + "Owner scope should fire matching owner event routine" + ); + tokio::time::sleep(Duration::from_millis(500)).await; + + let owner_runs = db + .list_routine_runs(routine.id, 10) + .await + .expect("list_routine_runs after owner message"); + assert_eq!( + owner_runs.len(), + 1, + "Owner message should create exactly one run" + ); + } + // ----------------------------------------------------------------------- // Test 3: system_event_trigger_matches_and_filters // ----------------------------------------------------------------------- @@ -434,7 +555,13 @@ mod tests { engine.refresh_event_cache().await; // First fire should work. - let msg = IncomingMessage::new("test", "default", "test-cooldown trigger"); + let msg = make_message( + "test", + "default", + "default", + "default", + "test-cooldown trigger", + ); let fired1 = engine.check_event_triggers(&msg).await; assert!(fired1 >= 1, "First fire should work"); @@ -553,4 +680,118 @@ mod tests { "Expected Skipped for empty checklist, got: {result:?}" ); } + + /// Helper to set up a test environment for routine engine mutation tests. + /// Returns the engine, database, and temp directory. + async fn setup_routine_mutation_test() + -> (Arc, Arc, tempfile::TempDir) { + let (db, dir) = create_test_db().await; + let ws = create_workspace(&db); + let (notify_tx, _rx) = tokio::sync::mpsc::channel(16); + let tools = Arc::new(ToolRegistry::new()); + + let safety_config = SafetyConfig { + max_output_length: 100_000, + injection_check_enabled: true, + }; + let safety = Arc::new(SafetyLayer::new(&safety_config)); + + let trace = LlmTrace::single_turn( + "test-routine-mutation", + "test", + vec![TraceStep { + request_hint: None, + response: TraceResponse::Text { + content: "ROUTINE_OK".to_string(), + input_tokens: 50, + output_tokens: 5, + }, + expected_tool_results: vec![], + }], + ); + let llm = Arc::new(TraceLlm::from_trace(trace)); + + let engine = Arc::new(RoutineEngine::new( + RoutineConfig::default(), + Arc::clone(&db), + llm, + ws, + notify_tx, + None, + tools, + safety, + )); + + (engine, db, dir) + } + + /// Regression test for issue #1076: disabling an event routine via a DB mutation + /// followed by refresh_event_cache() (the path now taken by the web toggle handler) + /// must immediately stop the routine from firing. + #[tokio::test] + async fn toggle_disabling_event_routine_removes_from_cache() { + let (engine, db, _dir) = setup_routine_mutation_test().await; + + // Create and cache an event routine. + let mut routine = make_routine( + "disable-me", + Trigger::Event { + pattern: "DISABLE_ME".to_string(), + channel: None, + }, + "Handle DISABLE_ME event", + ); + db.create_routine(&routine).await.expect("create_routine"); + engine.refresh_event_cache().await; + + let msg = IncomingMessage::new("test", "default", "DISABLE_ME"); + let fired_before = engine.check_event_triggers(&msg).await; + assert!(fired_before >= 1, "Expected routine to fire before disable"); + + // Simulate what routines_toggle_handler now does: update DB, then refresh. + routine.enabled = false; + routine.updated_at = Utc::now(); + db.update_routine(&routine).await.expect("update_routine"); + engine.refresh_event_cache().await; + + let fired_after = engine.check_event_triggers(&msg).await; + assert_eq!( + fired_after, 0, + "Disabled routine must not fire after cache refresh" + ); + } + + /// Regression test for issue #1076: deleting an event routine via a DB mutation + /// followed by refresh_event_cache() must immediately stop the routine from firing. + #[tokio::test] + async fn delete_event_routine_removes_from_cache() { + let (engine, db, _dir) = setup_routine_mutation_test().await; + + let routine = make_routine( + "delete-me", + Trigger::Event { + pattern: "DELETE_ME".to_string(), + channel: None, + }, + "Handle DELETE_ME event", + ); + db.create_routine(&routine).await.expect("create_routine"); + engine.refresh_event_cache().await; + + let msg = IncomingMessage::new("test", "default", "DELETE_ME"); + assert!( + engine.check_event_triggers(&msg).await >= 1, + "Expected routine to fire before delete" + ); + + // Simulate what routines_delete_handler now does: delete from DB, then refresh. + db.delete_routine(routine.id).await.expect("delete_routine"); + engine.refresh_event_cache().await; + + assert_eq!( + engine.check_event_triggers(&msg).await, + 0, + "Deleted routine must not fire after cache refresh" + ); + } } diff --git a/tests/support/gateway_workflow_harness.rs b/tests/support/gateway_workflow_harness.rs index c539dad504..a4d737b52a 100644 --- a/tests/support/gateway_workflow_harness.rs +++ b/tests/support/gateway_workflow_harness.rs @@ -239,6 +239,7 @@ impl GatewayWorkflowHarness { let mut agent = Agent::new( components.config.agent.clone(), AgentDeps { + owner_id: components.config.owner_id.clone(), store: components.db, llm: components.llm, cheap_llm: components.cheap_llm, diff --git a/tests/support/test_rig.rs b/tests/support/test_rig.rs index 07106e428e..8549a21cb1 100644 --- a/tests/support/test_rig.rs +++ b/tests/support/test_rig.rs @@ -612,6 +612,7 @@ impl TestRigBuilder { // 7. Construct AgentDeps from AppComponents (mirrors main.rs). let deps = AgentDeps { + owner_id: components.config.owner_id.clone(), store: components.db, llm: components.llm, cheap_llm: components.cheap_llm, diff --git a/tests/telegram_auth_integration.rs b/tests/telegram_auth_integration.rs index 8b27d8a8c8..0052f8a24f 100644 --- a/tests/telegram_auth_integration.rs +++ b/tests/telegram_auth_integration.rs @@ -6,17 +6,21 @@ //! 1. When owner_id is null and dm_policy is "allowlist", unauthorized users in //! group chats are dropped even if they @mention the bot //! 2. When owner_id is null and dm_policy is "open", all users can interact -//! 3. When owner_id is set, only that user can interact +//! 3. When owner_id is set, the owner gets instance-global access while +//! non-owner senders remain channel-scoped guests subject to authorization //! 4. Authorization works correctly for both private and group chats use std::collections::HashMap; use std::sync::Arc; +use futures::StreamExt; +use ironclaw::channels::Channel; use ironclaw::channels::wasm::{ ChannelCapabilities, PreparedChannelModule, WasmChannel, WasmChannelRuntime, WasmChannelRuntimeConfig, }; use ironclaw::pairing::PairingStore; +use tokio::time::{Duration, timeout}; /// Skip the test if the Telegram WASM module hasn't been built. /// In CI (detected via the `CI` env var), panic instead of skipping so a @@ -97,6 +101,14 @@ async fn load_telegram_module( async fn create_telegram_channel( runtime: Arc, config_json: &str, +) -> WasmChannel { + create_telegram_channel_with_store(runtime, config_json, Arc::new(PairingStore::new())).await +} + +async fn create_telegram_channel_with_store( + runtime: Arc, + config_json: &str, + pairing_store: Arc, ) -> WasmChannel { let module = load_telegram_module(&runtime) .await @@ -106,8 +118,9 @@ async fn create_telegram_channel( runtime, module, ChannelCapabilities::for_channel("telegram").with_path("/webhook/telegram"), + "default", config_json.to_string(), - Arc::new(PairingStore::new()), + pairing_store, None, ) } @@ -245,31 +258,29 @@ async fn test_group_message_authorized_user_allowed() { } #[tokio::test] -async fn test_group_message_with_owner_id_set() { +async fn test_private_message_with_owner_id_set_uses_guest_pairing_flow() { require_telegram_wasm!(); let runtime = create_test_runtime(); + let dir = tempfile::tempdir().expect("tempdir"); + let pairing_store = Arc::new(PairingStore::with_base_dir(dir.path().to_path_buf())); - // Config: owner_id=123 (only this user can interact) + // Config: owner_id=123, non-owner private DMs should enter the guest + // pairing flow instead of being rejected solely for not being the owner. let config = serde_json::json!({ - "bot_username": "test_bot", + "bot_username": null, "owner_id": 123, - "dm_policy": "allowlist", - "allow_from": ["anyone"], // ignored when owner_id is set + "dm_policy": "pairing", + "allow_from": [], "respond_to_all_group_messages": false }) .to_string(); - let channel = create_telegram_channel(runtime, &config).await; + let channel = create_telegram_channel_with_store(runtime, &config, pairing_store.clone()).await; - // Message from different user (should be dropped) + // Non-owner private message should produce a pairing request. let update = build_telegram_update( - 3, - 102, - -123456789, - "group", - 999, // Not the owner - "Other", - "Hey @test_bot hello", + 3, 102, 999, "private", 999, // Not the owner + "Other", "hello", ); let response = channel @@ -286,8 +297,64 @@ async fn test_group_message_with_owner_id_set() { assert_eq!(response.status, 200); - // REGRESSION TEST: Non-owner messages are dropped when owner_id is set - // This behavior is consistent and not affected by the fix + let pending = pairing_store + .list_pending("telegram") + .expect("pairing store should be readable"); + assert_eq!(pending.len(), 1); + assert_eq!(pending[0].id, "999"); +} + +#[tokio::test] +async fn test_private_messages_use_chat_id_as_thread_scope() { + require_telegram_wasm!(); + let runtime = create_test_runtime(); + + let config = serde_json::json!({ + "bot_username": null, + "owner_id": null, + "dm_policy": "open", + "allow_from": [], + "respond_to_all_group_messages": false + }) + .to_string(); + + let channel = create_telegram_channel(runtime, &config).await; + let mut stream = channel.start().await.expect("Failed to start channel"); + + for (update_id, message_id, text) in [(6, 105, "first"), (7, 106, "second")] { + let update = build_telegram_update( + update_id, + message_id, + 999, + "private", + 999, + "ThreadUser", + text, + ); + + let response = channel + .call_on_http_request( + "POST", + "/webhook/telegram", + &HashMap::new(), + &HashMap::new(), + &update, + true, + ) + .await + .expect("HTTP callback failed"); + + assert_eq!(response.status, 200); + + let msg = timeout(Duration::from_secs(1), stream.next()) + .await + .expect("message should arrive") + .expect("stream should yield a message"); + assert_eq!(msg.thread_id.as_deref(), Some("999")); + assert_eq!(msg.conversation_scope(), Some("999")); + } + + channel.shutdown().await.expect("Shutdown failed"); } #[tokio::test] diff --git a/tests/wasm_channel_integration.rs b/tests/wasm_channel_integration.rs index b5d1785b94..7e05c0f397 100644 --- a/tests/wasm_channel_integration.rs +++ b/tests/wasm_channel_integration.rs @@ -43,6 +43,7 @@ fn create_test_channel( runtime, prepared, capabilities, + "default", "{}".to_string(), Arc::new(PairingStore::new()), None,