diff --git a/.env.example b/.env.example index 765ea3f652..55c3adb52a 100644 --- a/.env.example +++ b/.env.example @@ -18,6 +18,11 @@ DATABASE_POOL_SIZE=10 # === OpenAI Direct === # OPENAI_API_KEY=sk-... +# Reuse Codex CLI auth.json instead of setting OPENAI_API_KEY manually. +# Works with both OpenAI API-key mode and Codex ChatGPT OAuth mode. +# In ChatGPT mode this uses the private `chatgpt.com/backend-api/codex` endpoint. +# LLM_USE_CODEX_AUTH=true +# CODEX_AUTH_PATH=~/.codex/auth.json # === NEAR AI (Chat Completions API) === # Two auth modes: diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 92f203b36a..ee16c0f8df 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -52,7 +52,7 @@ jobs: - group: features files: "tests/e2e/scenarios/test_skills.py tests/e2e/scenarios/test_tool_approval.py" - group: extensions - files: "tests/e2e/scenarios/test_extensions.py tests/e2e/scenarios/test_extension_oauth.py tests/e2e/scenarios/test_wasm_lifecycle.py tests/e2e/scenarios/test_tool_execution.py tests/e2e/scenarios/test_pairing.py tests/e2e/scenarios/test_oauth_credential_fallback.py tests/e2e/scenarios/test_routine_oauth_credential_injection.py" + files: "tests/e2e/scenarios/test_extensions.py tests/e2e/scenarios/test_extension_oauth.py tests/e2e/scenarios/test_telegram_token_validation.py tests/e2e/scenarios/test_wasm_lifecycle.py tests/e2e/scenarios/test_tool_execution.py tests/e2e/scenarios/test_pairing.py tests/e2e/scenarios/test_oauth_credential_fallback.py tests/e2e/scenarios/test_routine_oauth_credential_injection.py" steps: - uses: actions/checkout@v6 diff --git a/.gitignore b/.gitignore index ed64c2423b..2577b4a278 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,9 @@ trace_*.json # Local Claude Code settings (machine-specific, should not be committed) .claude/settings.local.json .worktrees/ + +# Python cache +__pycache__/ +*.pyc +*.pyo +*.pyd diff --git a/Cargo.lock b/Cargo.lock index dab77b8d38..854d103abf 100644 
--- a/Cargo.lock +++ b/Cargo.lock @@ -3461,6 +3461,7 @@ dependencies = [ "dirs 6.0.0", "dotenvy", "ed25519-dalek", + "eventsource-stream", "flate2", "fs4", "futures", @@ -4364,9 +4365,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.75" +version = "0.10.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" dependencies = [ "bitflags 2.11.0", "cfg-if", @@ -4402,9 +4403,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.111" +version = "0.9.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" dependencies = [ "cc", "libc", diff --git a/Cargo.toml b/Cargo.toml index 122c90ec34..aef4e6879a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ eula = false tokio = { version = "1", features = ["full"] } tokio-stream = { version = "0.1", features = ["sync"] } futures = "0.3" +eventsource-stream = "0.2" # HTTP client reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-native-roots", "stream"] } diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md index db4ab92a4c..0cda8caaac 100644 --- a/FEATURE_PARITY.md +++ b/FEATURE_PARITY.md @@ -68,7 +68,7 @@ This document tracks feature parity between IronClaw (Rust implementation) and O | REPL (simple) | ✅ | ✅ | - | For testing | | WASM channels | ❌ | ✅ | - | IronClaw innovation | | WhatsApp | ✅ | ❌ | P1 | Baileys (Web), same-phone mode with echo detection | -| Telegram | ✅ | ✅ | - | WASM channel(MTProto), DM pairing, caption, /start, bot_username, DM topics | +| Telegram | ✅ | ✅ | - | WASM channel(MTProto), DM pairing, caption, /start, 
bot_username, DM topics, setup-time owner verification | | Discord | ✅ | ❌ | P2 | discord.js, thread parent binding inheritance | | Signal | ✅ | ✅ | P2 | signal-cli daemonPC, SSE listener HTTP/JSON-R, user/group allowlists, DM pairing | | Slack | ✅ | ✅ | - | WASM tool | diff --git a/README.md b/README.md index b18d0d7d1a..9684ee4de6 100644 --- a/README.md +++ b/README.md @@ -166,13 +166,20 @@ written to `~/.ironclaw/.env` so they are available before the database connects ### Alternative LLM Providers -IronClaw defaults to NEAR AI but works with any OpenAI-compatible endpoint. -Popular options include **OpenRouter** (300+ models), **Together AI**, **Fireworks AI**, -**Ollama** (local), and self-hosted servers like **vLLM** or **LiteLLM**. +IronClaw defaults to NEAR AI but supports many LLM providers out of the box. +Built-in providers include **Anthropic**, **OpenAI**, **Google Gemini**, **MiniMax**, +**Mistral**, and **Ollama** (local). OpenAI-compatible services like **OpenRouter** +(300+ models), **Together AI**, **Fireworks AI**, and self-hosted servers (**vLLM**, +**LiteLLM**) are also supported. -Select *"OpenAI-compatible"* in the wizard, or set environment variables directly: +Select your provider in the wizard, or set environment variables directly: ```env +# Example: MiniMax (built-in, 204K context) +LLM_BACKEND=minimax +MINIMAX_API_KEY=... + +# Example: OpenAI-compatible endpoint LLM_BACKEND=openai_compatible LLM_BASE_URL=https://openrouter.ai/api/v1 LLM_API_KEY=sk-or-... diff --git a/README.ru.md b/README.ru.md index b534f0e503..c64770a96b 100644 --- a/README.ru.md +++ b/README.ru.md @@ -163,12 +163,20 @@ ironclaw onboard ### Альтернативные LLM-провайдеры -IronClaw по умолчанию использует NEAR AI, но работает с любыми OpenAI-совместимыми эндпоинтами. -Популярные варианты включают **OpenRouter** (300+ моделей), **Together AI**, **Fireworks AI**, **Ollama** (локально) и собственные серверы, такие как **vLLM** или **LiteLLM**. 
+IronClaw по умолчанию использует NEAR AI, но поддерживает множество LLM-провайдеров из коробки. +Встроенные провайдеры включают **Anthropic**, **OpenAI**, **Google Gemini**, **MiniMax**, +**Mistral** и **Ollama** (локально). Также поддерживаются OpenAI-совместимые сервисы: +**OpenRouter** (300+ моделей), **Together AI**, **Fireworks AI** и собственные серверы +(**vLLM**, **LiteLLM**). -Выберите *"OpenAI-compatible"* в мастере настройки или установите переменные окружения напрямую: +Выберите провайдера в мастере настройки или установите переменные окружения напрямую: ```env +# Пример: MiniMax (встроенный, контекст 204K) +LLM_BACKEND=minimax +MINIMAX_API_KEY=... + +# Пример: OpenAI-совместимый эндпоинт LLM_BACKEND=openai_compatible LLM_BASE_URL=https://openrouter.ai/api/v1 LLM_API_KEY=sk-or-... diff --git a/README.zh-CN.md b/README.zh-CN.md index c51afc60bc..3402382227 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -163,12 +163,17 @@ ironclaw onboard ### 替代 LLM 提供商 -IronClaw 默认使用 NEAR AI,但兼容任何 OpenAI 兼容的端点。 -常用选项包括 **OpenRouter**(300+ 模型)、**Together AI**、**Fireworks AI**、**Ollama**(本地部署)以及自托管服务器如 **vLLM** 或 **LiteLLM**。 +IronClaw 默认使用 NEAR AI,但开箱即用地支持多种 LLM 提供商。 +内置提供商包括 **Anthropic**、**OpenAI**、**Google Gemini**、**MiniMax**、**Mistral** 和 **Ollama**(本地部署)。同时也支持 OpenAI 兼容服务,如 **OpenRouter**(300+ 模型)、**Together AI**、**Fireworks AI** 以及自托管服务器(**vLLM**、**LiteLLM**)。 -在向导中选择 *"OpenAI-compatible"*,或直接设置环境变量: +在向导中选择你的提供商,或直接设置环境变量: ```env +# 示例:MiniMax(内置,204K 上下文) +LLM_BACKEND=minimax +MINIMAX_API_KEY=... + +# 示例:OpenAI 兼容端点 LLM_BACKEND=openai_compatible LLM_BASE_URL=https://openrouter.ai/api/v1 LLM_API_KEY=sk-or-... diff --git a/channels-src/telegram/src/lib.rs b/channels-src/telegram/src/lib.rs index d8718ebb91..936197bc04 100644 --- a/channels-src/telegram/src/lib.rs +++ b/channels-src/telegram/src/lib.rs @@ -100,6 +100,15 @@ struct TelegramMessage { /// Sticker. sticker: Option, + + /// Forum topic ID. Present when the message is sent inside a forum topic. 
+ /// https://core.telegram.org/bots/api#message + #[serde(default)] + message_thread_id: Option, + + /// True when this message is sent inside a forum topic. + #[serde(default)] + is_topic_message: Option, } /// Telegram PhotoSize object. @@ -198,6 +207,10 @@ struct TelegramChat { /// Title for groups/channels. title: Option, + /// True when the supergroup has topics (forum mode) enabled. + #[serde(default)] + is_forum: Option, + /// Username for private chats. username: Option, } @@ -290,6 +303,10 @@ struct TelegramMessageMetadata { /// Whether this is a private (DM) chat. is_private: bool, + + /// Forum topic thread ID (for routing replies back to the correct topic). + #[serde(default, skip_serializing_if = "Option::is_none")] + message_thread_id: Option, } /// Channel configuration injected by host. @@ -680,7 +697,7 @@ impl Guest for TelegramChannel { let metadata: TelegramMessageMetadata = serde_json::from_str(&response.metadata_json) .map_err(|e| format!("Failed to parse metadata: {}", e))?; - send_response(metadata.chat_id, &response, Some(metadata.message_id)) + send_response(metadata.chat_id, &response, Some(metadata.message_id), metadata.message_thread_id) } fn on_broadcast(user_id: String, response: AgentResponse) -> Result<(), String> { @@ -688,7 +705,7 @@ impl Guest for TelegramChannel { .parse() .map_err(|e| format!("Invalid chat_id '{}': {}", user_id, e))?; - send_response(chat_id, &response, None) + send_response(chat_id, &response, None, None) } fn on_status(update: StatusUpdate) { @@ -712,11 +729,17 @@ impl Guest for TelegramChannel { match action { TelegramStatusAction::Typing => { // POST /sendChatAction with action "typing" - let payload = serde_json::json!({ + let mut payload = serde_json::json!({ "chat_id": metadata.chat_id, "action": "typing" }); + // sendChatAction requires message_thread_id even for the General + // topic (id=1), unlike sendMessage which rejects it. 
+ if let Some(thread_id) = metadata.message_thread_id { + payload["message_thread_id"] = serde_json::Value::Number(thread_id.into()); + } + let payload_bytes = match serde_json::to_vec(&payload) { Ok(b) => b, Err(_) => return, @@ -744,7 +767,7 @@ impl Guest for TelegramChannel { TelegramStatusAction::Notify(prompt) => { // Send user-visible status updates for actionable events. if let Err(first_err) = - send_message(metadata.chat_id, &prompt, Some(metadata.message_id), None) + send_message(metadata.chat_id, &prompt, Some(metadata.message_id), None, metadata.message_thread_id) { channel_host::log( channel_host::LogLevel::Warn, @@ -754,7 +777,7 @@ impl Guest for TelegramChannel { ), ); - if let Err(retry_err) = send_message(metadata.chat_id, &prompt, None, None) { + if let Err(retry_err) = send_message(metadata.chat_id, &prompt, None, None, metadata.message_thread_id) { channel_host::log( channel_host::LogLevel::Debug, &format!( @@ -797,6 +820,15 @@ impl std::fmt::Display for SendError { } } +/// Normalize `message_thread_id` for outbound API calls. +/// +/// Telegram rejects `sendMessage` (and other send methods) when +/// `message_thread_id = 1` (the "General" topic). Return `None` in that +/// case so the field is omitted from the payload. +fn normalize_thread_id(thread_id: Option) -> Option { + thread_id.filter(|&id| id != 1) +} + /// Send a message via the Telegram Bot API. /// /// Returns the sent message_id on success. 
When `parse_mode` is set and @@ -807,7 +839,10 @@ fn send_message( text: &str, reply_to_message_id: Option, parse_mode: Option<&str>, + message_thread_id: Option, ) -> Result { + let message_thread_id = normalize_thread_id(message_thread_id); + let mut payload = serde_json::json!({ "chat_id": chat_id, "text": text, @@ -821,6 +856,10 @@ fn send_message( payload["parse_mode"] = serde_json::Value::String(mode.to_string()); } + if let Some(thread_id) = message_thread_id { + payload["message_thread_id"] = serde_json::Value::Number(thread_id.into()); + } + let payload_bytes = serde_json::to_vec(&payload) .map_err(|e| SendError::Other(format!("Failed to serialize payload: {}", e)))?; @@ -1036,7 +1075,10 @@ fn send_photo( mime_type: &str, data: &[u8], reply_to_message_id: Option, + message_thread_id: Option, ) -> Result<(), String> { + let message_thread_id = normalize_thread_id(message_thread_id); + if data.len() > MAX_PHOTO_SIZE { channel_host::log( channel_host::LogLevel::Info, @@ -1046,7 +1088,7 @@ fn send_photo( data.len() ), ); - return send_document(chat_id, filename, mime_type, data, reply_to_message_id); + return send_document(chat_id, filename, mime_type, data, reply_to_message_id, message_thread_id); } let boundary = format!("ironclaw-{}", channel_host::now_millis()); @@ -1056,6 +1098,9 @@ fn send_photo( if let Some(msg_id) = reply_to_message_id { write_multipart_field(&mut body, &boundary, "reply_to_message_id", &msg_id.to_string()); } + if let Some(thread_id) = message_thread_id { + write_multipart_field(&mut body, &boundary, "message_thread_id", &thread_id.to_string()); + } write_multipart_file(&mut body, &boundary, "photo", filename, mime_type, data); body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes()); @@ -1097,7 +1142,10 @@ fn send_document( mime_type: &str, data: &[u8], reply_to_message_id: Option, + message_thread_id: Option, ) -> Result<(), String> { + let message_thread_id = normalize_thread_id(message_thread_id); + let boundary = 
format!("ironclaw-{}", channel_host::now_millis()); let mut body = Vec::new(); @@ -1105,6 +1153,9 @@ fn send_document( if let Some(msg_id) = reply_to_message_id { write_multipart_field(&mut body, &boundary, "reply_to_message_id", &msg_id.to_string()); } + if let Some(thread_id) = message_thread_id { + write_multipart_field(&mut body, &boundary, "message_thread_id", &thread_id.to_string()); + } write_multipart_file(&mut body, &boundary, "document", filename, mime_type, data); body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes()); @@ -1154,10 +1205,11 @@ fn send_response( chat_id: i64, response: &AgentResponse, reply_to_message_id: Option, + message_thread_id: Option, ) -> Result<(), String> { // Send attachments first (photos/documents) for attachment in &response.attachments { - send_attachment(chat_id, attachment, reply_to_message_id)?; + send_attachment(chat_id, attachment, reply_to_message_id, message_thread_id)?; } // Skip text if empty and we already sent attachments @@ -1166,10 +1218,10 @@ fn send_response( } // Try Markdown, fall back to plain text on parse errors - match send_message(chat_id, &response.content, reply_to_message_id, Some("Markdown")) { + match send_message(chat_id, &response.content, reply_to_message_id, Some("Markdown"), message_thread_id) { Ok(_) => Ok(()), Err(SendError::ParseEntities(_)) => { - send_message(chat_id, &response.content, reply_to_message_id, None) + send_message(chat_id, &response.content, reply_to_message_id, None, message_thread_id) .map(|_| ()) .map_err(|e| format!("Plain-text retry also failed: {}", e)) } @@ -1182,6 +1234,7 @@ fn send_attachment( chat_id: i64, attachment: &Attachment, reply_to_message_id: Option, + message_thread_id: Option, ) -> Result<(), String> { if PHOTO_MIME_TYPES.contains(&attachment.mime_type.as_str()) { send_photo( @@ -1190,6 +1243,7 @@ fn send_attachment( &attachment.mime_type, &attachment.data, reply_to_message_id, + message_thread_id, ) } else { send_document( @@ -1198,6 +1252,7 
@@ fn send_attachment( &attachment.mime_type, &attachment.data, reply_to_message_id, + message_thread_id, ) } } @@ -1357,6 +1412,7 @@ fn send_pairing_reply(chat_id: i64, code: &str) -> Result<(), String> { ), None, Some("Markdown"), + None, // Pairing happens in DMs, not forum topics ) .map(|_| ()) .map_err(|e| e.to_string()) @@ -1774,6 +1830,8 @@ fn handle_message(message: TelegramMessage) { } } + let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); + // For group chats, only respond if bot was mentioned or respond_to_all is enabled if !is_private { let respond_to_all = channel_host::workspace_read(RESPOND_TO_ALL_GROUP_PATH) @@ -1783,7 +1841,6 @@ fn handle_message(message: TelegramMessage) { if !respond_to_all { let has_command = content.starts_with('/'); - let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); let has_bot_mention = if bot_username.is_empty() { content.contains('@') } else { @@ -1814,11 +1871,23 @@ fn handle_message(message: TelegramMessage) { message_id: message.message_id, user_id: from.id, is_private, + message_thread_id: message.message_thread_id, }; let metadata_json = serde_json::to_string(&metadata).unwrap_or_else(|_| "{}".to_string()); - let bot_username = channel_host::workspace_read(BOT_USERNAME_PATH).unwrap_or_default(); + // Compute thread_id for forum topics: "chat_id:topic_id" to prevent + // collisions across different groups (topic IDs are only unique per chat). + // Only use message_thread_id when the chat is a forum — non-forum groups + // also carry message_thread_id for reply threads, which are not topics. 
+ let thread_id = if message.chat.is_forum == Some(true) { + message.message_thread_id.map(|topic_id| { + format!("{}:{}", message.chat.id, topic_id) + }) + } else { + None + }; + let content_to_emit = match content_to_emit_for_agent( &content, if bot_username.is_empty() { @@ -1838,7 +1907,7 @@ fn handle_message(message: TelegramMessage) { user_id: from.id.to_string(), user_name: Some(user_name), content: content_to_emit, - thread_id: None, // Telegram doesn't have threads in the same way + thread_id, metadata_json, attachments, }); @@ -2657,4 +2726,100 @@ mod tests { // Verify the constant is 20 MB, matching the Slack channel limit assert_eq!(MAX_DOWNLOAD_SIZE_BYTES, 20 * 1024 * 1024); } + + // === Forum Topics (thread_id) tests === + + #[test] + fn test_parse_forum_message_with_thread_id() { + let json = r#"{ + "message_id": 100, + "message_thread_id": 42, + "is_topic_message": true, + "from": {"id": 1, "is_bot": false, "first_name": "A"}, + "chat": {"id": -1001234567890, "type": "supergroup", "is_forum": true}, + "text": "Hello from a topic" + }"#; + let msg: TelegramMessage = serde_json::from_str(json).unwrap(); + assert_eq!(msg.message_thread_id, Some(42)); + assert_eq!(msg.is_topic_message, Some(true)); + assert_eq!(msg.chat.is_forum, Some(true)); + } + + #[test] + fn test_parse_non_forum_message_backward_compat() { + let json = r#"{ + "message_id": 1, + "from": {"id": 1, "is_bot": false, "first_name": "A"}, + "chat": {"id": 1, "type": "private"}, + "text": "Hello" + }"#; + let msg: TelegramMessage = serde_json::from_str(json).unwrap(); + assert_eq!(msg.message_thread_id, None); + assert_eq!(msg.is_topic_message, None); + assert_eq!(msg.chat.is_forum, None); + } + + #[test] + fn test_metadata_with_message_thread_id() { + let metadata = TelegramMessageMetadata { + chat_id: -1001234567890, + message_id: 100, + user_id: 42, + is_private: false, + message_thread_id: Some(7), + }; + let json = serde_json::to_string(&metadata).unwrap(); + let parsed: 
TelegramMessageMetadata = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.message_thread_id, Some(7)); + } + + #[test] + fn test_metadata_backward_compat_no_thread_id() { + // Old metadata JSON without message_thread_id should deserialize with None + let json = r#"{"chat_id":123,"message_id":1,"user_id":42,"is_private":true}"#; + let metadata: TelegramMessageMetadata = serde_json::from_str(json).unwrap(); + assert_eq!(metadata.message_thread_id, None); + } + + #[test] + fn test_metadata_thread_id_not_serialized_when_none() { + let metadata = TelegramMessageMetadata { + chat_id: 123, + message_id: 1, + user_id: 42, + is_private: true, + message_thread_id: None, + }; + let json = serde_json::to_string(&metadata).unwrap(); + assert!(!json.contains("message_thread_id")); + } + + #[test] + fn test_thread_id_composition() { + // Verify "chat_id:topic_id" format for forum topics + let chat_id: i64 = -1001234567890; + let topic_id: i64 = 42; + let thread_id = format!("{}:{}", chat_id, topic_id); + assert_eq!(thread_id, "-1001234567890:42"); + } + + #[test] + fn test_normalize_thread_id_general_topic() { + // General topic (id=1) must be omitted — Telegram rejects sendMessage + // with message_thread_id=1. + assert_eq!(normalize_thread_id(Some(1)), None); + } + + #[test] + fn test_normalize_thread_id_regular_topic() { + // Non-General topics pass through unchanged + assert_eq!(normalize_thread_id(Some(42)), Some(42)); + assert_eq!(normalize_thread_id(Some(123)), Some(123)); + } + + #[test] + fn test_normalize_thread_id_none() { + // None stays None + assert_eq!(normalize_thread_id(None), None); + } } diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 5ca094e41a..4b7ed5381f 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -750,6 +750,20 @@ impl Agent { "Message details" ); + // Internal messages (e.g. 
job-monitor notifications) are already + // rendered text and should be forwarded directly to the user without + // entering the normal user-input pipeline (LLM/tool loop). + // The `is_internal` field and `into_internal()` setter are pub(crate), + // so external channels cannot spoof this flag. + if message.is_internal { + tracing::debug!( + message_id = %message.id, + channel = %message.channel, + "Forwarding internal message" + ); + return Ok(Some(message.content.clone())); + } + // Set message tool context for this turn (current channel and target) // For Signal, use signal_target from metadata (group:ID or phone number), // otherwise fall back to user_id diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs index a91f59a61a..9e6747f2b3 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher.rs @@ -143,6 +143,11 @@ impl Agent { JobContext::with_user(&message.user_id, "chat", "Interactive chat session"); job_ctx.http_interceptor = self.deps.http_interceptor.clone(); job_ctx.user_timezone = user_tz.name().to_string(); + job_ctx.metadata = serde_json::json!({ + "notify_channel": message.channel, + "notify_user": message.user_id, + "notify_thread_id": message.thread_id, + }); // Build system prompts once for this turn. Two variants: with tools // (normal iterations) and without (force_text final iteration). diff --git a/src/agent/heartbeat.rs b/src/agent/heartbeat.rs index 15c51b6104..77bdeadb0f 100644 --- a/src/agent/heartbeat.rs +++ b/src/agent/heartbeat.rs @@ -26,6 +26,8 @@ use std::sync::Arc; use std::time::Duration; +use chrono::TimeZone as _; +use chrono_tz::Tz; use tokio::sync::mpsc; use crate::channels::OutgoingResponse; @@ -37,7 +39,7 @@ use crate::workspace::hygiene::HygieneConfig; /// Configuration for the heartbeat runner. #[derive(Debug, Clone)] pub struct HeartbeatConfig { - /// Interval between heartbeat checks. + /// Interval between heartbeat checks (used when fire_at is not set). 
pub interval: Duration, /// Whether heartbeat is enabled. pub enabled: bool, @@ -47,11 +49,13 @@ pub struct HeartbeatConfig { pub notify_user_id: Option, /// Channel to notify on heartbeat findings. pub notify_channel: Option, + /// Fixed time-of-day to fire (24h). When set, interval is ignored. + pub fire_at: Option, /// Hour (0-23) when quiet hours start. pub quiet_hours_start: Option, /// Hour (0-23) when quiet hours end. pub quiet_hours_end: Option, - /// Timezone for quiet hours evaluation (IANA name). + /// Timezone for fire_at and quiet hours evaluation (IANA name). pub timezone: Option, } @@ -63,6 +67,7 @@ impl Default for HeartbeatConfig { max_failures: 3, notify_user_id: None, notify_channel: None, + fire_at: None, quiet_hours_start: None, quiet_hours_end: None, timezone: None, @@ -109,6 +114,21 @@ impl HeartbeatConfig { self.notify_channel = Some(channel.into()); self } + + /// Set a fixed time-of-day to fire (overrides interval). + pub fn with_fire_at(mut self, time: chrono::NaiveTime, tz: Option) -> Self { + self.fire_at = Some(time); + self.timezone = tz; + self + } + + /// Resolve timezone string to chrono_tz::Tz (defaults to UTC). + fn resolved_tz(&self) -> Tz { + self.timezone + .as_deref() + .and_then(crate::timezone::parse_timezone) + .unwrap_or(chrono_tz::UTC) + } } /// Result of a heartbeat check. @@ -124,6 +144,33 @@ pub enum HeartbeatResult { Failed(String), } +/// Compute how long to sleep until the next occurrence of `fire_at` in `tz`. +/// +/// If the target time today is still in the future, sleep until then. +/// Otherwise sleep until the same time tomorrow. +fn duration_until_next_fire(fire_at: chrono::NaiveTime, tz: Tz) -> Duration { + let now = chrono::Utc::now().with_timezone(&tz); + let today = now.date_naive(); + + // Try to build today's target datetime in the given timezone. + // `.earliest()` picks the first occurrence if DST creates ambiguity. 
+ let candidate = tz.from_local_datetime(&today.and_time(fire_at)).earliest(); + + let target = match candidate { + Some(t) if t > now => t, + _ => { + // Already past (or ambiguous) — schedule for tomorrow + let tomorrow = today + chrono::Duration::days(1); + tz.from_local_datetime(&tomorrow.and_time(fire_at)) + .earliest() + .unwrap_or_else(|| now + chrono::Duration::days(1)) + } + }; + + let secs = (target - now).num_seconds().max(1) as u64; + Duration::from_secs(secs) +} + /// Heartbeat runner for proactive periodic execution. pub struct HeartbeatRunner { config: HeartbeatConfig, @@ -175,17 +222,39 @@ impl HeartbeatRunner { return; } - tracing::info!( - "Starting heartbeat loop with interval {:?}", - self.config.interval - ); + // Two scheduling modes: + // fire_at → sleep until the next occurrence (recalculated each iteration) + // interval → tokio::time::interval (drift-free, accounts for loop body time) + let mut tick_interval = if self.config.fire_at.is_none() { + let mut iv = tokio::time::interval(self.config.interval); + // Don't fire immediately on startup. 
+ iv.tick().await; + Some(iv) + } else { + None + }; - let mut interval = tokio::time::interval(self.config.interval); - // Don't run immediately on startup - interval.tick().await; + if let Some(fire_at) = self.config.fire_at { + tracing::info!( + "Starting heartbeat loop: fire daily at {:?} {:?}", + fire_at, + self.config.timezone + ); + } else { + tracing::info!( + "Starting heartbeat loop with interval {:?}", + self.config.interval + ); + } loop { - interval.tick().await; + if let Some(fire_at) = self.config.fire_at { + let sleep_dur = duration_until_next_fire(fire_at, self.config.resolved_tz()); + tracing::info!("Next heartbeat in {:.1}h", sleep_dur.as_secs_f64() / 3600.0); + tokio::time::sleep(sleep_dur).await; + } else if let Some(ref mut iv) = tick_interval { + iv.tick().await; + } // Skip during quiet hours if self.config.is_quiet_hours() { @@ -656,4 +725,63 @@ mod tests { ) -> tokio::task::JoinHandle<()> = spawn_heartbeat; let _ = _fn_ptr; } + + // ==================== fire_at scheduling ==================== + + #[test] + fn test_default_config_has_no_fire_at() { + let config = HeartbeatConfig::default(); + assert!(config.fire_at.is_none()); + // Interval-based scheduling should be the default + assert_eq!(config.interval, Duration::from_secs(30 * 60)); + } + + #[test] + fn test_with_fire_at_builder() { + let time = chrono::NaiveTime::from_hms_opt(9, 0, 0).unwrap(); + let config = + HeartbeatConfig::default().with_fire_at(time, Some("Pacific/Auckland".to_string())); + assert_eq!(config.fire_at, Some(time)); + assert_eq!(config.timezone, Some("Pacific/Auckland".to_string())); + } + + #[test] + fn test_duration_until_next_fire_is_bounded() { + // Result must always be between 1 second and ~24 hours + let time = chrono::NaiveTime::from_hms_opt(14, 0, 0).unwrap(); + let dur = duration_until_next_fire(time, chrono_tz::UTC); + assert!(dur.as_secs() >= 1, "duration must be at least 1 second"); + assert!( + dur.as_secs() <= 86_401, + "duration must be at most ~24 
hours, got {}s", + dur.as_secs() + ); + } + + #[test] + fn test_duration_until_next_fire_dst_timezone_no_panic() { + // Use a timezone with DST (US Eastern) — should never panic + let tz: Tz = "America/New_York".parse().unwrap(); + // Test a range of times including midnight boundaries + for hour in [0, 2, 3, 12, 23] { + let time = chrono::NaiveTime::from_hms_opt(hour, 30, 0).unwrap(); + let dur = duration_until_next_fire(time, tz); + assert!(dur.as_secs() >= 1); + assert!(dur.as_secs() <= 86_401); + } + } + + #[test] + fn test_resolved_tz_defaults_to_utc() { + let config = HeartbeatConfig::default(); + assert_eq!(config.resolved_tz(), chrono_tz::UTC); + } + + #[test] + fn test_resolved_tz_parses_iana() { + let time = chrono::NaiveTime::from_hms_opt(9, 0, 0).unwrap(); + let config = + HeartbeatConfig::default().with_fire_at(time, Some("Europe/London".to_string())); + assert_eq!(config.resolved_tz(), chrono_tz::Europe::London); + } } diff --git a/src/agent/job_monitor.rs b/src/agent/job_monitor.rs index b2db885222..714caeac4b 100644 --- a/src/agent/job_monitor.rs +++ b/src/agent/job_monitor.rs @@ -21,6 +21,14 @@ use uuid::Uuid; use crate::channels::IncomingMessage; use crate::channels::web::types::SseEvent; +/// Route context for forwarding job monitor events back to the user's channel. +#[derive(Debug, Clone)] +pub struct JobMonitorRoute { + pub channel: String, + pub user_id: String, + pub thread_id: Option, +} + /// Spawn a background task that watches for events from a specific job and /// injects assistant messages into the agent loop. /// @@ -35,6 +43,7 @@ pub fn spawn_job_monitor( job_id: Uuid, mut event_rx: broadcast::Receiver<(Uuid, SseEvent)>, inject_tx: mpsc::Sender, + route: JobMonitorRoute, ) -> JoinHandle<()> { let short_id = job_id.to_string()[..8].to_string(); @@ -50,11 +59,15 @@ pub fn spawn_job_monitor( match event { SseEvent::JobMessage { role, content, .. 
} if role == "assistant" => { - let msg = IncomingMessage::new( - "job_monitor", - "system", + let mut msg = IncomingMessage::new( + route.channel.clone(), + route.user_id.clone(), format!("[Job {}] Claude Code: {}", short_id, content), - ); + ) + .into_internal(); + if let Some(ref thread_id) = route.thread_id { + msg = msg.with_thread(thread_id.clone()); + } if inject_tx.send(msg).await.is_err() { tracing::debug!( job_id = %short_id, @@ -64,14 +77,18 @@ pub fn spawn_job_monitor( } } SseEvent::JobResult { status, .. } => { - let msg = IncomingMessage::new( - "job_monitor", - "system", + let mut msg = IncomingMessage::new( + route.channel.clone(), + route.user_id.clone(), format!( "[Job {}] Container finished (status: {})", short_id, status ), - ); + ) + .into_internal(); + if let Some(ref thread_id) = route.thread_id { + msg = msg.with_thread(thread_id.clone()); + } let _ = inject_tx.send(msg).await; tracing::debug!( job_id = %short_id, @@ -108,13 +125,21 @@ pub fn spawn_job_monitor( mod tests { use super::*; + fn test_route() -> JobMonitorRoute { + JobMonitorRoute { + channel: "cli".to_string(), + user_id: "user-1".to_string(), + thread_id: Some("thread-1".to_string()), + } + } + #[tokio::test] async fn test_monitor_forwards_assistant_messages() { let (event_tx, _) = broadcast::channel::<(Uuid, SseEvent)>(16); let (inject_tx, mut inject_rx) = mpsc::channel::(16); let job_id = Uuid::new_v4(); - let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx); + let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx, test_route()); // Send an assistant message event_tx @@ -133,9 +158,11 @@ mod tests { .unwrap() .unwrap(); - assert_eq!(msg.channel, "job_monitor"); - assert_eq!(msg.user_id, "system"); + assert_eq!(msg.channel, "cli"); + assert_eq!(msg.user_id, "user-1"); + assert_eq!(msg.thread_id, Some("thread-1".to_string())); assert!(msg.content.contains("I found a bug")); + assert!(msg.is_internal, "monitor messages must be marked 
internal"); } #[tokio::test] @@ -145,7 +172,7 @@ mod tests { let job_id = Uuid::new_v4(); let other_job_id = Uuid::new_v4(); - let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx); + let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx, test_route()); // Send a message for a different job event_tx @@ -174,7 +201,7 @@ mod tests { let (inject_tx, mut inject_rx) = mpsc::channel::(16); let job_id = Uuid::new_v4(); - let handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx); + let handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx, test_route()); // Send a completion event event_tx @@ -208,7 +235,7 @@ mod tests { let (inject_tx, mut inject_rx) = mpsc::channel::(16); let job_id = Uuid::new_v4(); - let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx); + let _handle = spawn_job_monitor(job_id, event_tx.subscribe(), inject_tx, test_route()); // Send tool use event (should be skipped) event_tx @@ -242,4 +269,28 @@ mod tests { "should have timed out, no message expected" ); } + + /// Regression test: external channels must not be able to spoof the + /// `is_internal` flag via metadata keys. A message created through + /// the normal `IncomingMessage::new` + `with_metadata` path must + /// always have `is_internal == false`, regardless of metadata content. 
+ #[test] + fn test_external_metadata_cannot_spoof_internal_flag() { + let msg = IncomingMessage::new("wasm_channel", "attacker", "pwned").with_metadata( + serde_json::json!({ + "__internal_job_monitor": true, + "is_internal": true, + }), + ); + assert!( + !msg.is_internal, + "with_metadata must not set is_internal — only into_internal() can" + ); + } + + #[test] + fn test_into_internal_sets_flag() { + let msg = IncomingMessage::new("monitor", "system", "test").into_internal(); + assert!(msg.is_internal); + } } diff --git a/src/channels/channel.rs b/src/channels/channel.rs index 1fc76fd74f..ed8c28ff2e 100644 --- a/src/channels/channel.rs +++ b/src/channels/channel.rs @@ -83,6 +83,11 @@ pub struct IncomingMessage { pub timezone: Option, /// File or media attachments on this message. pub attachments: Vec, + /// Internal-only flag: message was generated inside the process (e.g. job + /// monitor) and must bypass the normal user-input pipeline. This field is + /// **not** settable via `with_metadata()` — only trusted code paths inside + /// the binary can set it, preventing external channels from spoofing it. + pub(crate) is_internal: bool, } impl IncomingMessage { @@ -103,6 +108,7 @@ impl IncomingMessage { metadata: serde_json::Value::Null, timezone: None, attachments: Vec::new(), + is_internal: false, } } @@ -135,6 +141,12 @@ impl IncomingMessage { self.attachments = attachments; self } + + /// Mark this message as internal (bypasses user-input pipeline). + pub(crate) fn into_internal(mut self) -> Self { + self.is_internal = true; + self + } } /// Stream of incoming messages. 
diff --git a/src/channels/wasm/mod.rs b/src/channels/wasm/mod.rs index 0d4a6c3f66..dba843417d 100644 --- a/src/channels/wasm/mod.rs +++ b/src/channels/wasm/mod.rs @@ -90,6 +90,7 @@ pub mod setup; pub(crate) mod signature; #[allow(dead_code)] pub(crate) mod storage; +mod telegram_host_config; mod wrapper; // Core types @@ -107,4 +108,5 @@ pub use schema::{ ChannelCapabilitiesFile, ChannelConfig, SecretSetupSchema, SetupSchema, WebhookSchema, }; pub use setup::{WasmChannelSetup, inject_channel_credentials, setup_wasm_channels}; +pub(crate) use telegram_host_config::{TELEGRAM_CHANNEL_NAME, bot_username_setting_key}; pub use wrapper::{HttpResponse, SharedWasmChannel, WasmChannel}; diff --git a/src/channels/wasm/setup.rs b/src/channels/wasm/setup.rs index b9deb5261e..9c0c3f33a4 100644 --- a/src/channels/wasm/setup.rs +++ b/src/channels/wasm/setup.rs @@ -7,8 +7,9 @@ use std::collections::HashSet; use std::sync::Arc; use crate::channels::wasm::{ - LoadedChannel, RegisteredEndpoint, SharedWasmChannel, WasmChannel, WasmChannelLoader, - WasmChannelRouter, WasmChannelRuntime, WasmChannelRuntimeConfig, create_wasm_channel_router, + LoadedChannel, RegisteredEndpoint, SharedWasmChannel, TELEGRAM_CHANNEL_NAME, WasmChannel, + WasmChannelLoader, WasmChannelRouter, WasmChannelRuntime, WasmChannelRuntimeConfig, + bot_username_setting_key, create_wasm_channel_router, }; use crate::config::Config; use crate::db::Database; @@ -48,7 +49,7 @@ pub async fn setup_wasm_channels( let mut loader = WasmChannelLoader::new( Arc::clone(&runtime), Arc::clone(&pairing_store), - settings_store, + settings_store.clone(), ); if let Some(secrets) = secrets_store { loader = loader.with_secrets_store(Arc::clone(secrets)); @@ -70,7 +71,14 @@ pub async fn setup_wasm_channels( let mut channel_names: Vec = Vec::new(); for loaded in results.loaded { - let (name, channel) = register_channel(loaded, config, secrets_store, &wasm_router).await; + let (name, channel) = register_channel( + loaded, + config, + 
secrets_store, + settings_store.as_ref(), + &wasm_router, + ) + .await; channel_names.push(name.clone()); channels.push((name, channel)); } @@ -104,6 +112,7 @@ async fn register_channel( loaded: LoadedChannel, config: &Config, secrets_store: &Option>, + settings_store: Option<&Arc>, wasm_router: &Arc, ) -> (String, Box) { let channel_name = loaded.name().to_string(); @@ -161,6 +170,15 @@ async fn register_channel( config_updates.insert("owner_id".to_string(), serde_json::json!(owner_id)); } + if channel_name == TELEGRAM_CHANNEL_NAME + && let Some(store) = settings_store + && let Ok(Some(serde_json::Value::String(username))) = store + .get_setting("default", &bot_username_setting_key(&channel_name)) + .await + && !username.trim().is_empty() + { + config_updates.insert("bot_username".to_string(), serde_json::json!(username)); + } // Inject channel-specific secrets into config for channels that need // credentials in API request bodies (e.g., Feishu token exchange). // The credential injection system only replaces placeholders in URLs diff --git a/src/channels/wasm/telegram_host_config.rs b/src/channels/wasm/telegram_host_config.rs new file mode 100644 index 0000000000..79c27c0bfc --- /dev/null +++ b/src/channels/wasm/telegram_host_config.rs @@ -0,0 +1,6 @@ +pub const TELEGRAM_CHANNEL_NAME: &str = "telegram"; +const TELEGRAM_BOT_USERNAME_SETTING_PREFIX: &str = "channels.wasm_channel_bot_usernames"; + +pub fn bot_username_setting_key(channel_name: &str) -> String { + format!("{TELEGRAM_BOT_USERNAME_SETTING_PREFIX}.{channel_name}") +} diff --git a/src/channels/web/handlers/chat.rs b/src/channels/web/handlers/chat.rs index 909a252cf4..5cb2b9ea1b 100644 --- a/src/channels/web/handlers/chat.rs +++ b/src/channels/web/handlers/chat.rs @@ -162,15 +162,30 @@ pub async fn chat_auth_token_handler( .await { Ok(result) => { - clear_auth_mode(&state).await; + let mut resp = ActionResponse::ok(result.message.clone()); + resp.activated = Some(result.activated); + resp.auth_url = 
result.auth_url.clone(); + resp.verification = result.verification.clone(); + resp.instructions = result.verification.as_ref().map(|v| v.instructions.clone()); - state.sse.broadcast(SseEvent::AuthCompleted { - extension_name: req.extension_name.clone(), - success: true, - message: result.message.clone(), - }); + if result.verification.is_some() { + state.sse.broadcast(SseEvent::AuthRequired { + extension_name: req.extension_name.clone(), + instructions: Some(result.message), + auth_url: None, + setup_url: None, + }); + } else { + clear_auth_mode(&state).await; + + state.sse.broadcast(SseEvent::AuthCompleted { + extension_name: req.extension_name.clone(), + success: true, + message: result.message, + }); + } - Ok(Json(ActionResponse::ok(result.message))) + Ok(Json(resp)) } Err(e) => { let msg = e.to_string(); diff --git a/src/channels/web/handlers/extensions.rs b/src/channels/web/handlers/extensions.rs index 3c490eac1a..855fba3ed9 100644 --- a/src/channels/web/handlers/extensions.rs +++ b/src/channels/web/handlers/extensions.rs @@ -25,34 +25,34 @@ pub async fn extensions_list_handler( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let pairing_store = crate::pairing::PairingStore::new(); + let mut owner_bound_channels = std::collections::HashSet::new(); + for ext in &installed { + if ext.kind == crate::extensions::ExtensionKind::WasmChannel + && ext_mgr.has_wasm_channel_owner_binding(&ext.name).await + { + owner_bound_channels.insert(ext.name.clone()); + } + } let extensions = installed .into_iter() .map(|ext| { let activation_status = if ext.kind == crate::extensions::ExtensionKind::WasmChannel { - Some(if ext.activation_error.is_some() { - "failed".to_string() - } else if !ext.authenticated { - "installed".to_string() - } else if ext.active { - let has_paired = pairing_store - .read_allow_from(&ext.name) - .map(|list| !list.is_empty()) - .unwrap_or(false); - if has_paired { - "active".to_string() - } else { - "pairing".to_string() - } - } else { 
- "configured".to_string() - }) + let has_paired = pairing_store + .read_allow_from(&ext.name) + .map(|list| !list.is_empty()) + .unwrap_or(false); + crate::channels::web::types::classify_wasm_channel_activation( + &ext, + has_paired, + owner_bound_channels.contains(&ext.name), + ) } else if ext.kind == crate::extensions::ExtensionKind::ChannelRelay { Some(if ext.active { - "active".to_string() + crate::channels::web::types::ExtensionActivationStatus::Active } else if ext.authenticated { - "configured".to_string() + crate::channels::web::types::ExtensionActivationStatus::Configured } else { - "installed".to_string() + crate::channels::web::types::ExtensionActivationStatus::Installed }) } else { None diff --git a/src/channels/web/server.rs b/src/channels/web/server.rs index e8cb33c220..fb8c93ae23 100644 --- a/src/channels/web/server.rs +++ b/src/channels/web/server.rs @@ -1163,19 +1163,43 @@ async fn chat_auth_token_handler( .configure_token(&req.extension_name, &req.token) .await { - Ok(result) if result.activated => { - // Clear auth mode on the active thread - clear_auth_mode(&state).await; + Ok(result) => { + let mut resp = if result.verification.is_some() || result.activated { + ActionResponse::ok(result.message.clone()) + } else { + ActionResponse::fail(result.message.clone()) + }; + resp.activated = Some(result.activated); + resp.auth_url = result.auth_url.clone(); + resp.verification = result.verification.clone(); + resp.instructions = result.verification.as_ref().map(|v| v.instructions.clone()); - state.sse.broadcast(SseEvent::AuthCompleted { - extension_name: req.extension_name.clone(), - success: true, - message: result.message.clone(), - }); + if result.verification.is_some() { + state.sse.broadcast(SseEvent::AuthRequired { + extension_name: req.extension_name.clone(), + instructions: Some(result.message), + auth_url: None, + setup_url: None, + }); + } else if result.activated { + // Clear auth mode on the active thread + clear_auth_mode(&state).await; + 
+ state.sse.broadcast(SseEvent::AuthCompleted { + extension_name: req.extension_name.clone(), + success: true, + message: result.message, + }); + } else { + state.sse.broadcast(SseEvent::AuthCompleted { + extension_name: req.extension_name.clone(), + success: false, + message: result.message, + }); + } - Ok(Json(ActionResponse::ok(result.message))) + Ok(Json(resp)) } - Ok(result) => Ok(Json(ActionResponse::fail(result.message))), Err(e) => { let msg = e.to_string(); // Re-emit auth_required for retry on validation errors @@ -1818,29 +1842,34 @@ async fn extensions_list_handler( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let pairing_store = crate::pairing::PairingStore::new(); + let mut owner_bound_channels = std::collections::HashSet::new(); + for ext in &installed { + if ext.kind == crate::extensions::ExtensionKind::WasmChannel + && ext_mgr.has_wasm_channel_owner_binding(&ext.name).await + { + owner_bound_channels.insert(ext.name.clone()); + } + } let extensions = installed .into_iter() .map(|ext| { let activation_status = if ext.kind == crate::extensions::ExtensionKind::WasmChannel { - Some(if ext.activation_error.is_some() { - "failed".to_string() - } else if !ext.authenticated { - // No credentials configured yet. - "installed".to_string() - } else if ext.active { - // Check pairing status for active channels. 
- let has_paired = pairing_store - .read_allow_from(&ext.name) - .map(|list| !list.is_empty()) - .unwrap_or(false); - if has_paired { - "active".to_string() - } else { - "pairing".to_string() - } + let has_paired = pairing_store + .read_allow_from(&ext.name) + .map(|list| !list.is_empty()) + .unwrap_or(false); + crate::channels::web::types::classify_wasm_channel_activation( + &ext, + has_paired, + owner_bound_channels.contains(&ext.name), + ) + } else if ext.kind == crate::extensions::ExtensionKind::ChannelRelay { + Some(if ext.active { + ExtensionActivationStatus::Active + } else if ext.authenticated { + ExtensionActivationStatus::Configured } else { - // Authenticated but not yet active. - "configured".to_string() + ExtensionActivationStatus::Installed }) } else { None @@ -2205,20 +2234,31 @@ async fn extensions_setup_submit_handler( match ext_mgr.configure(&name, &req.secrets).await { Ok(result) => { - // Broadcast completion status so chat UI can dismiss success cases while - // leaving failed auth/configuration flows visible for correction. 
- state.sse.broadcast(SseEvent::AuthCompleted { - extension_name: name.clone(), - success: result.activated, - message: result.message.clone(), - }); - let mut resp = if result.activated { + let mut resp = if result.verification.is_some() || result.activated { ActionResponse::ok(result.message) } else { ActionResponse::fail(result.message) }; resp.activated = Some(result.activated); - resp.auth_url = result.auth_url; + resp.auth_url = result.auth_url.clone(); + resp.verification = result.verification.clone(); + resp.instructions = result.verification.as_ref().map(|v| v.instructions.clone()); + if result.verification.is_some() { + state.sse.broadcast(SseEvent::AuthRequired { + extension_name: name.clone(), + instructions: resp.instructions.clone(), + auth_url: None, + setup_url: None, + }); + } else { + // Broadcast auth_completed so the chat UI can dismiss any in-progress + // auth card or setup modal that was triggered by tool_auth/tool_activate. + state.sse.broadcast(SseEvent::AuthCompleted { + extension_name: name.clone(), + success: result.activated, + message: resp.message.clone(), + }); + } Ok(Json(resp)) } Err(e) => Ok(Json(ActionResponse::fail(e.to_string()))), @@ -2743,7 +2783,11 @@ struct GatewayStatusResponse { #[cfg(test)] mod tests { use super::*; + use crate::channels::web::types::{ + ExtensionActivationStatus, classify_wasm_channel_activation, + }; use crate::cli::oauth_defaults; + use crate::extensions::{ExtensionKind, InstalledExtension}; use crate::testing::credentials::TEST_GATEWAY_CRYPTO_KEY; #[test] @@ -2822,6 +2866,85 @@ mod tests { assert!(turns.is_empty()); } + #[test] + fn test_wasm_channel_activation_status_owner_bound_counts_as_active() -> Result<(), String> { + let ext = InstalledExtension { + name: "telegram".to_string(), + kind: ExtensionKind::WasmChannel, + display_name: Some("Telegram".to_string()), + description: None, + url: None, + authenticated: true, + active: true, + tools: Vec::new(), + needs_setup: true, + has_auth: false, + 
installed: true, + activation_error: None, + version: None, + }; + + let owner_bound = classify_wasm_channel_activation(&ext, false, true); + if owner_bound != Some(ExtensionActivationStatus::Active) { + return Err(format!( + "owner-bound channel should be active, got {:?}", + owner_bound + )); + } + + let unbound = classify_wasm_channel_activation(&ext, false, false); + if unbound != Some(ExtensionActivationStatus::Pairing) { + return Err(format!( + "unbound channel should be pairing, got {:?}", + unbound + )); + } + + Ok(()) + } + + #[test] + fn test_channel_relay_activation_status_is_preserved() -> Result<(), String> { + let relay = InstalledExtension { + name: "signal".to_string(), + kind: ExtensionKind::ChannelRelay, + display_name: Some("Signal".to_string()), + description: None, + url: None, + authenticated: true, + active: false, + tools: Vec::new(), + needs_setup: true, + has_auth: false, + installed: true, + activation_error: None, + version: None, + }; + + let status = if relay.kind == crate::extensions::ExtensionKind::WasmChannel { + classify_wasm_channel_activation(&relay, false, false) + } else if relay.kind == crate::extensions::ExtensionKind::ChannelRelay { + Some(if relay.active { + ExtensionActivationStatus::Active + } else if relay.authenticated { + ExtensionActivationStatus::Configured + } else { + ExtensionActivationStatus::Installed + }) + } else { + None + }; + + if status != Some(ExtensionActivationStatus::Configured) { + return Err(format!( + "channel relay should retain configured status, got {:?}", + status + )); + } + + Ok(()) + } + // --- OAuth callback handler tests --- /// Build a minimal `GatewayState` for testing the OAuth callback handler. 
diff --git a/src/channels/web/static/app.js b/src/channels/web/static/app.js index d32968a9a3..127c18fa0c 100644 --- a/src/channels/web/static/app.js +++ b/src/channels/web/static/app.js @@ -2723,6 +2723,13 @@ function renderConfigureModal(name, secrets) { header.textContent = I18n.t('config.title', { name: name }); modal.appendChild(header); + if (name === 'telegram') { + const hint = document.createElement('div'); + hint.className = 'configure-hint'; + hint.textContent = I18n.t('config.telegramOwnerHint'); + modal.appendChild(hint); + } + const form = document.createElement('div'); form.className = 'configure-form'; @@ -2796,6 +2803,46 @@ function renderConfigureModal(name, secrets) { if (fields.length > 0) fields[0].input.focus(); } +function renderTelegramVerificationChallenge(overlay, verification) { + if (!overlay || !verification) return; + const modal = overlay.querySelector('.configure-modal'); + if (!modal) return; + + let panel = modal.querySelector('.configure-verification'); + if (!panel) { + panel = document.createElement('div'); + panel.className = 'configure-verification'; + modal.insertBefore(panel, modal.querySelector('.configure-actions')); + } + + panel.innerHTML = ''; + + const title = document.createElement('div'); + title.className = 'configure-verification-title'; + title.textContent = I18n.t('config.telegramChallengeTitle'); + panel.appendChild(title); + + const instructions = document.createElement('div'); + instructions.className = 'configure-verification-instructions'; + instructions.textContent = verification.instructions; + panel.appendChild(instructions); + + const code = document.createElement('code'); + code.className = 'configure-verification-code'; + code.textContent = verification.code; + panel.appendChild(code); + + if (verification.deep_link) { + const link = document.createElement('a'); + link.className = 'configure-verification-link'; + link.href = verification.deep_link; + link.target = '_blank'; + link.rel = 'noreferrer 
noopener'; + link.textContent = I18n.t('config.telegramOpenBot'); + panel.appendChild(link); + } +} + function submitConfigureModal(name, fields) { const secrets = {}; for (const f of fields) { @@ -2808,6 +2855,10 @@ function submitConfigureModal(name, fields) { const overlay = getConfigureOverlay(name) || document.querySelector('.configure-overlay'); var btns = overlay ? overlay.querySelectorAll('.configure-actions button') : []; btns.forEach(function(b) { b.disabled = true; }); + if (overlay && name === 'telegram') { + const submitBtn = overlay.querySelector('.configure-actions button.btn-ext.activate'); + if (submitBtn) submitBtn.textContent = I18n.t('config.telegramOwnerWaiting'); + } apiFetch('/api/extensions/' + encodeURIComponent(name) + '/setup', { method: 'POST', @@ -2815,6 +2866,16 @@ function submitConfigureModal(name, fields) { }) .then((res) => { if (res.success) { + if (res.verification && name === 'telegram') { + btns.forEach(function(b) { b.disabled = false; }); + renderTelegramVerificationChallenge(overlay, res.verification); + fields.forEach(function(f) { f.input.value = ''; }); + const submitBtn = overlay.querySelector('.configure-actions button.btn-ext.activate'); + if (submitBtn) submitBtn.textContent = I18n.t('config.telegramVerifyOwner'); + showToast(res.message || res.verification.instructions, 'info'); + return; + } + closeConfigureModal(); if (res.auth_url) { showAuthCard({ @@ -2830,11 +2891,29 @@ function submitConfigureModal(name, fields) { } else { // Keep modal open so the user can correct their input and retry. btns.forEach(function(b) { b.disabled = false; }); + if (name === 'telegram') { + const submitBtn = overlay && overlay.querySelector('.configure-actions button.btn-ext.activate'); + const hasVerification = overlay && overlay.querySelector('.configure-verification'); + if (submitBtn) { + submitBtn.textContent = hasVerification + ? 
I18n.t('config.telegramVerifyOwner') + : I18n.t('config.save'); + } + } showToast(res.message || 'Configuration failed', 'error'); } }) .catch((err) => { btns.forEach(function(b) { b.disabled = false; }); + if (name === 'telegram') { + const submitBtn = overlay && overlay.querySelector('.configure-actions button.btn-ext.activate'); + const hasVerification = overlay && overlay.querySelector('.configure-verification'); + if (submitBtn) { + submitBtn.textContent = hasVerification + ? I18n.t('config.telegramVerifyOwner') + : I18n.t('config.save'); + } + } showToast('Configuration failed: ' + err.message, 'error'); }); } diff --git a/src/channels/web/static/i18n/en.js b/src/channels/web/static/i18n/en.js index b637f14484..42e996da0a 100644 --- a/src/channels/web/static/i18n/en.js +++ b/src/channels/web/static/i18n/en.js @@ -342,6 +342,11 @@ I18n.register('en', { // Configure 'config.title': 'Configure {name}', + 'config.telegramOwnerHint': 'After saving, IronClaw will show a one-time code. Send `/start CODE` to your bot in Telegram, then click Verify owner.', + 'config.telegramChallengeTitle': 'Telegram owner verification', + 'config.telegramOwnerWaiting': 'Waiting for Telegram owner verification...', + 'config.telegramVerifyOwner': 'Verify owner', + 'config.telegramOpenBot': 'Open bot in Telegram', 'config.optional': ' (optional)', 'config.alreadySet': '(already set — leave empty to keep)', 'config.alreadyConfigured': 'Already configured', diff --git a/src/channels/web/static/style.css b/src/channels/web/static/style.css index 0ba5766f1d..44fd91762f 100644 --- a/src/channels/web/static/style.css +++ b/src/channels/web/static/style.css @@ -2896,6 +2896,62 @@ body { color: var(--text-primary); } +.configure-hint { + margin: 0 0 16px 0; + padding: 10px 12px; + border-radius: 8px; + background: var(--bg-secondary); + border: 1px solid var(--border); + color: var(--text-secondary); + font-size: 13px; + line-height: 1.5; +} + +.configure-verification { + display: flex; + 
flex-direction: column; + gap: 10px; + margin: 16px 0 0 0; + padding: 12px; + border-radius: 8px; + background: var(--bg-secondary); + border: 1px solid var(--border); +} + +.configure-verification-title { + font-size: 13px; + font-weight: 600; + color: var(--text-primary); +} + +.configure-verification-instructions { + font-size: 13px; + line-height: 1.5; + color: var(--text-secondary); +} + +.configure-verification-code { + display: inline-block; + width: fit-content; + padding: 6px 10px; + border-radius: 6px; + background: rgba(255, 255, 255, 0.06); + border: 1px solid var(--border); + color: var(--text-primary); + font-size: 13px; +} + +.configure-verification-link { + width: fit-content; + color: var(--accent, var(--text-link, #4ea3ff)); + font-size: 13px; + text-decoration: none; +} + +.configure-verification-link:hover { + text-decoration: underline; +} + .configure-form { display: flex; flex-direction: column; diff --git a/src/channels/web/types.rs b/src/channels/web/types.rs index 129a70717c..3fad9f3525 100644 --- a/src/channels/web/types.rs +++ b/src/channels/web/types.rs @@ -410,6 +410,40 @@ pub struct TransitionInfo { // --- Extensions --- +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ExtensionActivationStatus { + Installed, + Configured, + Pairing, + Active, + Failed, +} + +pub fn classify_wasm_channel_activation( + ext: &crate::extensions::InstalledExtension, + has_paired: bool, + has_owner_binding: bool, +) -> Option { + if ext.kind != crate::extensions::ExtensionKind::WasmChannel { + return None; + } + + Some(if ext.activation_error.is_some() { + ExtensionActivationStatus::Failed + } else if !ext.authenticated { + ExtensionActivationStatus::Installed + } else if ext.active { + if has_paired || has_owner_binding { + ExtensionActivationStatus::Active + } else { + ExtensionActivationStatus::Pairing + } + } else { + ExtensionActivationStatus::Configured + }) +} + #[derive(Debug, Serialize)] pub 
struct ExtensionInfo { pub name: String, @@ -428,9 +462,9 @@ pub struct ExtensionInfo { /// Whether this extension has an auth configuration (OAuth or manual token). #[serde(default)] pub has_auth: bool, - /// WASM channel activation status: "installed", "configured", "active", "failed". + /// WASM channel activation status. #[serde(skip_serializing_if = "Option::is_none")] - pub activation_status: Option, + pub activation_status: Option, /// Human-readable error when activation_status is "failed". #[serde(skip_serializing_if = "Option::is_none")] pub activation_error: Option, @@ -503,6 +537,9 @@ pub struct ActionResponse { /// Whether the channel was successfully activated after setup. #[serde(skip_serializing_if = "Option::is_none")] pub activated: Option, + /// Pending manual verification challenge (for Telegram owner binding, etc.). + #[serde(skip_serializing_if = "Option::is_none")] + pub verification: Option, } impl ActionResponse { @@ -514,6 +551,7 @@ impl ActionResponse { awaiting_token: None, instructions: None, activated: None, + verification: None, } } @@ -525,6 +563,7 @@ impl ActionResponse { awaiting_token: None, instructions: None, activated: None, + verification: None, } } } diff --git a/src/channels/web/ws.rs b/src/channels/web/ws.rs index 7287902e2f..7bf50e52a9 100644 --- a/src/channels/web/ws.rs +++ b/src/channels/web/ws.rs @@ -265,14 +265,25 @@ async fn handle_client_message( if let Some(ref ext_mgr) = state.extension_manager { match ext_mgr.configure_token(&extension_name, &token).await { Ok(result) => { - crate::channels::web::server::clear_auth_mode(state).await; - state - .sse - .broadcast(crate::channels::web::types::SseEvent::AuthCompleted { - extension_name, - success: true, - message: result.message, - }); + if result.verification.is_some() { + state.sse.broadcast( + crate::channels::web::types::SseEvent::AuthRequired { + extension_name: extension_name.clone(), + instructions: Some(result.message), + auth_url: None, + setup_url: None, + 
}, + ); + } else { + crate::channels::web::server::clear_auth_mode(state).await; + state.sse.broadcast( + crate::channels::web::types::SseEvent::AuthCompleted { + extension_name, + success: true, + message: result.message, + }, + ); + } } Err(e) => { let msg = format!("Auth failed: {}", e); diff --git a/src/config/builder.rs b/src/config/builder.rs index 90bbb1852f..088db90c63 100644 --- a/src/config/builder.rs +++ b/src/config/builder.rs @@ -32,13 +32,16 @@ impl Default for BuilderModeConfig { } impl BuilderModeConfig { - pub(crate) fn resolve() -> Result { + pub(crate) fn resolve(settings: &crate::settings::Settings) -> Result { + let bs = &settings.builder; Ok(Self { - enabled: parse_bool_env("BUILDER_ENABLED", true)?, - build_dir: optional_env("BUILDER_DIR")?.map(PathBuf::from), - max_iterations: parse_optional_env("BUILDER_MAX_ITERATIONS", 20)?, - timeout_secs: parse_optional_env("BUILDER_TIMEOUT_SECS", 600)?, - auto_register: parse_bool_env("BUILDER_AUTO_REGISTER", true)?, + enabled: parse_bool_env("BUILDER_ENABLED", bs.enabled)?, + build_dir: optional_env("BUILDER_DIR")? 
+ .map(PathBuf::from) + .or_else(|| bs.build_dir.clone()), + max_iterations: parse_optional_env("BUILDER_MAX_ITERATIONS", bs.max_iterations)?, + timeout_secs: parse_optional_env("BUILDER_TIMEOUT_SECS", bs.timeout_secs)?, + auto_register: parse_bool_env("BUILDER_AUTO_REGISTER", bs.auto_register)?, }) } @@ -56,3 +59,36 @@ impl BuilderModeConfig { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::helpers::ENV_MUTEX; + use crate::settings::Settings; + + #[test] + fn resolve_falls_back_to_settings() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + let mut settings = Settings::default(); + settings.builder.max_iterations = 99; + settings.builder.auto_register = false; + + let cfg = BuilderModeConfig::resolve(&settings).expect("resolve"); + assert_eq!(cfg.max_iterations, 99); + assert!(!cfg.auto_register); + } + + #[test] + fn env_overrides_settings() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + let mut settings = Settings::default(); + settings.builder.timeout_secs = 123; + + // SAFETY: Under ENV_MUTEX, no concurrent env access. + unsafe { std::env::set_var("BUILDER_TIMEOUT_SECS", "3") }; + let cfg = BuilderModeConfig::resolve(&settings).expect("resolve"); + unsafe { std::env::remove_var("BUILDER_TIMEOUT_SECS") }; + + assert_eq!(cfg.timeout_secs, 3); + } +} diff --git a/src/config/database.rs b/src/config/database.rs index 44abc09b26..55d8baea7f 100644 --- a/src/config/database.rs +++ b/src/config/database.rs @@ -170,6 +170,40 @@ impl DatabaseConfig { }) } + /// Create a config from a raw PostgreSQL URL (for wizard/testing). + pub fn from_postgres_url(url: &str, pool_size: usize) -> Self { + Self { + backend: DatabaseBackend::Postgres, + url: SecretString::from(url.to_string()), + pool_size, + ssl_mode: SslMode::from_env(), + libsql_path: None, + libsql_url: None, + libsql_auth_token: None, + } + } + + /// Create a config for a libSQL database (for wizard/testing). 
+ /// + /// Empty strings for `turso_url` and `turso_token` are treated as `None`. + pub fn from_libsql_path( + path: &str, + turso_url: Option<&str>, + turso_token: Option<&str>, + ) -> Self { + let turso_url = turso_url.filter(|s| !s.is_empty()); + let turso_token = turso_token.filter(|s| !s.is_empty()); + Self { + backend: DatabaseBackend::LibSql, + url: SecretString::from("unused://libsql".to_string()), + pool_size: 1, + ssl_mode: SslMode::default(), + libsql_path: Some(PathBuf::from(path)), + libsql_url: turso_url.map(String::from), + libsql_auth_token: turso_token.map(|t| SecretString::from(t.to_string())), + } + } + /// Get the database URL (exposes the secret). pub fn url(&self) -> &str { self.url.expose_secret() diff --git a/src/config/heartbeat.rs b/src/config/heartbeat.rs index 3de1da6632..1dd456d7fa 100644 --- a/src/config/heartbeat.rs +++ b/src/config/heartbeat.rs @@ -7,17 +7,19 @@ use crate::settings::Settings; pub struct HeartbeatConfig { /// Whether heartbeat is enabled. pub enabled: bool, - /// Interval between heartbeat checks in seconds. + /// Interval between heartbeat checks in seconds (used when fire_at is not set). pub interval_secs: u64, /// Channel to notify on heartbeat findings. pub notify_channel: Option, /// User ID to notify on heartbeat findings. pub notify_user: Option, + /// Fixed time-of-day to fire (HH:MM, 24h). When set, interval_secs is ignored. + pub fire_at: Option, /// Hour (0-23) when quiet hours start. pub quiet_hours_start: Option, /// Hour (0-23) when quiet hours end. pub quiet_hours_end: Option, - /// Timezone for quiet hours evaluation (IANA name). + /// Timezone for fire_at and quiet hours evaluation (IANA name). 
pub timezone: Option, } @@ -28,6 +30,7 @@ impl Default for HeartbeatConfig { interval_secs: 1800, // 30 minutes notify_channel: None, notify_user: None, + fire_at: None, quiet_hours_start: None, quiet_hours_end: None, timezone: None, @@ -37,6 +40,19 @@ impl Default for HeartbeatConfig { impl HeartbeatConfig { pub(crate) fn resolve(settings: &Settings) -> Result { + let fire_at_str = + optional_env("HEARTBEAT_FIRE_AT")?.or_else(|| settings.heartbeat.fire_at.clone()); + let fire_at = fire_at_str + .map(|s| { + chrono::NaiveTime::parse_from_str(&s, "%H:%M").map_err(|e| { + ConfigError::InvalidValue { + key: "HEARTBEAT_FIRE_AT".to_string(), + message: format!("must be HH:MM (24h), e.g. '14:00': {e}"), + } + }) + }) + .transpose()?; + Ok(Self { enabled: parse_bool_env("HEARTBEAT_ENABLED", settings.heartbeat.enabled)?, interval_secs: parse_optional_env( @@ -47,6 +63,7 @@ impl HeartbeatConfig { .or_else(|| settings.heartbeat.notify_channel.clone()), notify_user: optional_env("HEARTBEAT_NOTIFY_USER")? .or_else(|| settings.heartbeat.notify_user.clone()), + fire_at, quiet_hours_start: parse_option_env::("HEARTBEAT_QUIET_START")? .or(settings.heartbeat.quiet_hours_start) .map(|h| { diff --git a/src/config/llm.rs b/src/config/llm.rs index 31b8ff4c2b..4ad2439928 100644 --- a/src/config/llm.rs +++ b/src/config/llm.rs @@ -9,7 +9,6 @@ use crate::llm::config::*; use crate::llm::registry::{ProviderProtocol, ProviderRegistry}; use crate::llm::session::SessionConfig; use crate::settings::Settings; - impl LlmConfig { /// Create a test-friendly config without reading env vars. #[cfg(feature = "libsql")] @@ -241,8 +240,30 @@ impl LlmConfig { ) }; - // Resolve API key from env - let api_key = if let Some(env_var) = api_key_env { + // Codex auth.json override: when LLM_USE_CODEX_AUTH=true, + // credentials from the Codex CLI's auth.json take highest priority + // (over env vars AND secrets store). 
In ChatGPT mode, the base URL + // is also overridden to the private ChatGPT backend endpoint. + let mut codex_base_url_override: Option = None; + let codex_creds = if parse_optional_env("LLM_USE_CODEX_AUTH", false)? { + let path = optional_env("CODEX_AUTH_PATH")? + .map(std::path::PathBuf::from) + .unwrap_or_else(crate::llm::codex_auth::default_codex_auth_path); + crate::llm::codex_auth::load_codex_credentials(&path) + } else { + None + }; + + let codex_refresh_token = codex_creds.as_ref().and_then(|c| c.refresh_token.clone()); + let codex_auth_path = codex_creds.as_ref().and_then(|c| c.auth_path.clone()); + + let api_key = if let Some(creds) = codex_creds { + if creds.is_chatgpt_mode { + codex_base_url_override = Some(creds.base_url().to_string()); + } + Some(creds.token) + } else if let Some(env_var) = api_key_env { + // Resolve API key from env (including secrets store overlay) optional_env(env_var)?.map(SecretString::from) } else { None @@ -259,22 +280,28 @@ impl LlmConfig { } } - // Resolve base URL: env var > settings (backward compat) > registry default - let base_url = if let Some(env_var) = base_url_env { - optional_env(env_var)? 
- } else { - None - } - .or_else(|| { - // Backward compat: check legacy settings fields - match backend { - "ollama" => settings.ollama_base_url.clone(), - "openai_compatible" | "openrouter" => settings.openai_compatible_base_url.clone(), - _ => None, - } - }) - .or_else(|| default_base_url.map(String::from)) - .unwrap_or_default(); + // Resolve base URL: codex override > env var > settings (backward compat) > registry default + let is_codex_chatgpt = codex_base_url_override.is_some(); + let base_url = codex_base_url_override + .or_else(|| { + if let Some(env_var) = base_url_env { + optional_env(env_var).ok().flatten() + } else { + None + } + }) + .or_else(|| { + // Backward compat: check legacy settings fields + match backend { + "ollama" => settings.ollama_base_url.clone(), + "openai_compatible" | "openrouter" => { + settings.openai_compatible_base_url.clone() + } + _ => None, + } + }) + .or_else(|| default_base_url.map(String::from)) + .unwrap_or_default(); if base_url_required && base_url.is_empty() @@ -340,6 +367,9 @@ impl LlmConfig { model, extra_headers, oauth_token, + is_codex_chatgpt, + refresh_token: codex_refresh_token, + auth_path: codex_auth_path, cache_retention, unsupported_params, }) diff --git a/src/config/mod.rs b/src/config/mod.rs index 5299796391..1c81329e11 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -317,15 +317,15 @@ impl Config { channels: ChannelsConfig::resolve(settings, tunnel.is_enabled())?, tunnel, agent: AgentConfig::resolve(settings)?, - safety: resolve_safety_config()?, - wasm: WasmConfig::resolve()?, + safety: resolve_safety_config(settings)?, + wasm: WasmConfig::resolve(settings)?, secrets: SecretsConfig::resolve().await?, - builder: BuilderModeConfig::resolve()?, + builder: BuilderModeConfig::resolve(settings)?, heartbeat: HeartbeatConfig::resolve(settings)?, hygiene: HygieneConfig::resolve()?, routines: RoutineConfig::resolve()?, - sandbox: SandboxModeConfig::resolve()?, - claude_code: ClaudeCodeConfig::resolve()?, 
+ sandbox: SandboxModeConfig::resolve(settings)?, + claude_code: ClaudeCodeConfig::resolve(settings)?, skills: SkillsConfig::resolve()?, transcription: TranscriptionConfig::resolve(settings)?, search: WorkspaceSearchConfig::resolve()?, diff --git a/src/config/safety.rs b/src/config/safety.rs index f804d6ad7e..ff9e900a51 100644 --- a/src/config/safety.rs +++ b/src/config/safety.rs @@ -3,9 +3,48 @@ use crate::error::ConfigError; pub use ironclaw_safety::SafetyConfig; -pub(crate) fn resolve_safety_config() -> Result { +pub(crate) fn resolve_safety_config( + settings: &crate::settings::Settings, +) -> Result { + let ss = &settings.safety; Ok(SafetyConfig { - max_output_length: parse_optional_env("SAFETY_MAX_OUTPUT_LENGTH", 100_000)?, - injection_check_enabled: parse_bool_env("SAFETY_INJECTION_CHECK_ENABLED", true)?, + max_output_length: parse_optional_env("SAFETY_MAX_OUTPUT_LENGTH", ss.max_output_length)?, + injection_check_enabled: parse_bool_env( + "SAFETY_INJECTION_CHECK_ENABLED", + ss.injection_check_enabled, + )?, }) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::helpers::ENV_MUTEX; + use crate::settings::Settings; + + #[test] + fn resolve_falls_back_to_settings() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + let mut settings = Settings::default(); + settings.safety.max_output_length = 42; + settings.safety.injection_check_enabled = false; + + let cfg = resolve_safety_config(&settings).expect("resolve"); + assert_eq!(cfg.max_output_length, 42); + assert!(!cfg.injection_check_enabled); + } + + #[test] + fn env_overrides_settings() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + let mut settings = Settings::default(); + settings.safety.max_output_length = 42; + + // SAFETY: Under ENV_MUTEX, no concurrent env access. 
+ unsafe { std::env::set_var("SAFETY_MAX_OUTPUT_LENGTH", "7") }; + let cfg = resolve_safety_config(&settings).expect("resolve"); + unsafe { std::env::remove_var("SAFETY_MAX_OUTPUT_LENGTH") }; + + assert_eq!(cfg.max_output_length, 7); + } +} diff --git a/src/config/sandbox.rs b/src/config/sandbox.rs index e9b7ca7684..8c0eb689ae 100644 --- a/src/config/sandbox.rs +++ b/src/config/sandbox.rs @@ -52,11 +52,20 @@ impl Default for SandboxModeConfig { } impl SandboxModeConfig { - pub(crate) fn resolve() -> Result { + pub(crate) fn resolve(settings: &crate::settings::Settings) -> Result { + let ss = &settings.sandbox; + let extra_domains = optional_env("SANDBOX_EXTRA_DOMAINS")? .map(|s| s.split(',').map(|d| d.trim().to_string()).collect()) - .unwrap_or_default(); + .unwrap_or_else(|| { + if ss.extra_allowed_domains.is_empty() { + Vec::new() + } else { + ss.extra_allowed_domains.clone() + } + }); + // reaper/orphan fields have no Settings counterpart — env > default only. let reaper_interval_secs: u64 = parse_optional_env("SANDBOX_REAPER_INTERVAL_SECS", 300)?; let orphan_threshold_secs: u64 = parse_optional_env("SANDBOX_ORPHAN_THRESHOLD_SECS", 600)?; @@ -76,14 +85,15 @@ impl SandboxModeConfig { } Ok(Self { - enabled: parse_bool_env("SANDBOX_ENABLED", true)?, - policy: parse_string_env("SANDBOX_POLICY", "readonly")?, + enabled: parse_bool_env("SANDBOX_ENABLED", ss.enabled)?, + policy: parse_string_env("SANDBOX_POLICY", ss.policy.clone())?, + // allow_full_access has no Settings counterpart — env > default only. 
allow_full_access: parse_bool_env("SANDBOX_ALLOW_FULL_ACCESS", false)?, - timeout_secs: parse_optional_env("SANDBOX_TIMEOUT_SECS", 120)?, - memory_limit_mb: parse_optional_env("SANDBOX_MEMORY_LIMIT_MB", 2048)?, - cpu_shares: parse_optional_env("SANDBOX_CPU_SHARES", 1024)?, - image: parse_string_env("SANDBOX_IMAGE", "ironclaw-worker:latest")?, - auto_pull_image: parse_bool_env("SANDBOX_AUTO_PULL", true)?, + timeout_secs: parse_optional_env("SANDBOX_TIMEOUT_SECS", ss.timeout_secs)?, + memory_limit_mb: parse_optional_env("SANDBOX_MEMORY_LIMIT_MB", ss.memory_limit_mb)?, + cpu_shares: parse_optional_env("SANDBOX_CPU_SHARES", ss.cpu_shares)?, + image: parse_string_env("SANDBOX_IMAGE", ss.image.clone())?, + auto_pull_image: parse_bool_env("SANDBOX_AUTO_PULL", ss.auto_pull_image)?, extra_allowed_domains: extra_domains, reaper_interval_secs, orphan_threshold_secs, @@ -200,7 +210,7 @@ impl ClaudeCodeConfig { /// Load from environment variables only (used inside containers where /// there is no database or full config). pub fn from_env() -> Self { - match Self::resolve() { + match Self::resolve_env_only() { Ok(c) => c, Err(e) => { tracing::warn!("Failed to resolve ClaudeCodeConfig: {e}, using defaults"); @@ -253,7 +263,33 @@ impl ClaudeCodeConfig { None } - pub(crate) fn resolve() -> Result { + pub(crate) fn resolve(settings: &crate::settings::Settings) -> Result { + let defaults = Self::default(); + Ok(Self { + // Use settings.sandbox.claude_code_enabled as fallback (written by setup wizard). + enabled: parse_bool_env("CLAUDE_CODE_ENABLED", settings.sandbox.claude_code_enabled)?, + config_dir: optional_env("CLAUDE_CONFIG_DIR")? 
+ .map(std::path::PathBuf::from) + .unwrap_or(defaults.config_dir), + model: parse_string_env("CLAUDE_CODE_MODEL", defaults.model)?, + max_turns: parse_optional_env("CLAUDE_CODE_MAX_TURNS", defaults.max_turns)?, + memory_limit_mb: parse_optional_env( + "CLAUDE_CODE_MEMORY_LIMIT_MB", + defaults.memory_limit_mb, + )?, + allowed_tools: optional_env("CLAUDE_CODE_ALLOWED_TOOLS")? + .map(|s| { + s.split(',') + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()) + .collect() + }) + .unwrap_or(defaults.allowed_tools), + }) + } + + /// Resolve from env vars only, no Settings. Used inside containers. + fn resolve_env_only() -> Result { let defaults = Self::default(); Ok(Self { enabled: parse_bool_env("CLAUDE_CODE_ENABLED", defaults.enabled)?, @@ -554,6 +590,80 @@ mod tests { ); } + // ── Settings fallback tests ────────────────────────────────────── + + #[test] + fn sandbox_resolve_falls_back_to_settings() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let mut settings = crate::settings::Settings::default(); + settings.sandbox.cpu_shares = 99; + settings.sandbox.auto_pull_image = false; + settings.sandbox.enabled = false; + + let cfg = SandboxModeConfig::resolve(&settings).expect("resolve"); + assert!(!cfg.enabled); + assert_eq!(cfg.cpu_shares, 99); + assert!(!cfg.auto_pull_image); + } + + #[test] + fn sandbox_env_overrides_settings() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let mut settings = crate::settings::Settings::default(); + settings.sandbox.timeout_secs = 999; + + // SAFETY: Under ENV_MUTEX, no concurrent env access. 
+ unsafe { std::env::set_var("SANDBOX_TIMEOUT_SECS", "5") }; + let cfg = SandboxModeConfig::resolve(&settings).expect("resolve"); + unsafe { std::env::remove_var("SANDBOX_TIMEOUT_SECS") }; + + assert_eq!(cfg.timeout_secs, 5); + } + + // ── ClaudeCodeConfig settings fallback tests ──────────────────── + + #[test] + fn claude_code_resolve_uses_settings_enabled() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let mut settings = crate::settings::Settings::default(); + settings.sandbox.claude_code_enabled = true; + + let cfg = ClaudeCodeConfig::resolve(&settings).expect("resolve"); + assert!(cfg.enabled); + } + + #[test] + fn claude_code_resolve_defaults_disabled() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let settings = crate::settings::Settings::default(); + let cfg = ClaudeCodeConfig::resolve(&settings).expect("resolve"); + assert!(!cfg.enabled); + } + + #[test] + fn claude_code_env_overrides_settings() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let mut settings = crate::settings::Settings::default(); + settings.sandbox.claude_code_enabled = true; + + // SAFETY: Under ENV_MUTEX, no concurrent env access. 
+ unsafe { std::env::set_var("CLAUDE_CODE_ENABLED", "false") }; + let cfg = ClaudeCodeConfig::resolve(&settings).expect("resolve"); + unsafe { std::env::remove_var("CLAUDE_CODE_ENABLED") }; + + assert!(!cfg.enabled); + } + #[test] fn test_readonly_policy_unaffected() { let config = SandboxModeConfig { diff --git a/src/config/wasm.rs b/src/config/wasm.rs index 224f2e9532..a9bfbd3566 100644 --- a/src/config/wasm.rs +++ b/src/config/wasm.rs @@ -44,20 +44,30 @@ fn default_tools_dir() -> PathBuf { } impl WasmConfig { - pub(crate) fn resolve() -> Result { + pub(crate) fn resolve(settings: &crate::settings::Settings) -> Result { + let ws = &settings.wasm; Ok(Self { - enabled: parse_bool_env("WASM_ENABLED", true)?, + enabled: parse_bool_env("WASM_ENABLED", ws.enabled)?, tools_dir: optional_env("WASM_TOOLS_DIR")? .map(PathBuf::from) + .or_else(|| ws.tools_dir.clone()) .unwrap_or_else(default_tools_dir), default_memory_limit: parse_optional_env( "WASM_DEFAULT_MEMORY_LIMIT", - 10 * 1024 * 1024, + ws.default_memory_limit, )?, - default_timeout_secs: parse_optional_env("WASM_DEFAULT_TIMEOUT_SECS", 60)?, - default_fuel_limit: parse_optional_env("WASM_DEFAULT_FUEL_LIMIT", 10_000_000)?, - cache_compiled: parse_bool_env("WASM_CACHE_COMPILED", true)?, - cache_dir: optional_env("WASM_CACHE_DIR")?.map(PathBuf::from), + default_timeout_secs: parse_optional_env( + "WASM_DEFAULT_TIMEOUT_SECS", + ws.default_timeout_secs, + )?, + default_fuel_limit: parse_optional_env( + "WASM_DEFAULT_FUEL_LIMIT", + ws.default_fuel_limit, + )?, + cache_compiled: parse_bool_env("WASM_CACHE_COMPILED", ws.cache_compiled)?, + cache_dir: optional_env("WASM_CACHE_DIR")? 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::helpers::ENV_MUTEX;
    use crate::settings::Settings;

    /// With no env vars set, `resolve()` must take its values from Settings.
    #[test]
    fn resolve_falls_back_to_settings() {
        let _guard = ENV_MUTEX.lock().expect("env mutex poisoned");

        let mut settings = Settings::default();
        settings.wasm.default_memory_limit = 42;
        settings.wasm.cache_compiled = false;

        let resolved = WasmConfig::resolve(&settings).expect("resolve");
        assert_eq!(resolved.default_memory_limit, 42);
        assert!(!resolved.cache_compiled);
    }

    /// An explicit env var must win over the Settings fallback.
    #[test]
    fn env_overrides_settings() {
        let _guard = ENV_MUTEX.lock().expect("env mutex poisoned");

        let mut settings = Settings::default();
        settings.wasm.default_fuel_limit = 42;

        // SAFETY: Under ENV_MUTEX, no concurrent env access.
        unsafe { std::env::set_var("WASM_DEFAULT_FUEL_LIMIT", "7") };
        let resolved = WasmConfig::resolve(&settings).expect("resolve");
        unsafe { std::env::remove_var("WASM_DEFAULT_FUEL_LIMIT") };

        assert_eq!(resolved.default_fuel_limit, 7);
    }
}
+ if self.state == new_state { + tracing::debug!( + job_id = %self.job_id, + state = %self.state, + "idempotent state transition (already in target state), skipping" + ); + return Ok(()); + } + let transition = StateTransition { from: self.state, to: new_state, @@ -340,6 +360,45 @@ mod tests { assert!(!JobState::Accepted.can_transition_to(JobState::InProgress)); } + #[test] + fn test_completed_to_completed_is_idempotent() { + // Regression test for the race condition where both execution_loop + // and the worker wrapper call mark_completed(). The second call + // must succeed without error and must not record a duplicate + // transition. + let mut ctx = JobContext::new("Test", "Idempotent completion test"); + ctx.transition_to(JobState::InProgress, None).unwrap(); + ctx.transition_to(JobState::Completed, Some("first".into())) + .unwrap(); + assert_eq!(ctx.state, JobState::Completed); + let transitions_before = ctx.transitions.len(); + + // Second Completed -> Completed must be a no-op + let result = ctx.transition_to(JobState::Completed, Some("duplicate".into())); + assert!( + result.is_ok(), + "Completed -> Completed should be idempotent" + ); + assert_eq!(ctx.state, JobState::Completed); + assert_eq!( + ctx.transitions.len(), + transitions_before, + "idempotent transition should not record a new history entry" + ); + } + + #[test] + fn test_other_self_transitions_still_rejected() { + // Ensure we only allow Completed -> Completed, not arbitrary X -> X. 
+ assert!(!JobState::Pending.can_transition_to(JobState::Pending)); + assert!(!JobState::InProgress.can_transition_to(JobState::InProgress)); + assert!(!JobState::Failed.can_transition_to(JobState::Failed)); + assert!(!JobState::Stuck.can_transition_to(JobState::Stuck)); + assert!(!JobState::Submitted.can_transition_to(JobState::Submitted)); + assert!(!JobState::Accepted.can_transition_to(JobState::Accepted)); + assert!(!JobState::Cancelled.can_transition_to(JobState::Cancelled)); + } + #[test] fn test_terminal_states() { assert!(JobState::Accepted.is_terminal()); diff --git a/src/db/mod.rs b/src/db/mod.rs index a306c14bc1..6d2eb2960c 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -104,7 +104,7 @@ pub async fn connect_with_handles( Ok((Arc::new(backend) as Arc, handles)) } #[cfg(feature = "postgres")] - _ => { + crate::config::DatabaseBackend::Postgres => { let pg = postgres::PgBackend::new(config) .await .map_err(|e| DatabaseError::Pool(e.to_string()))?; @@ -115,10 +115,11 @@ pub async fn connect_with_handles( Ok((Arc::new(pg) as Arc, handles)) } - #[cfg(not(feature = "postgres"))] - _ => Err(DatabaseError::Pool( - "No database backend available. Enable 'postgres' or 'libsql' feature.".to_string(), - )), + #[allow(unreachable_patterns)] + _ => Err(DatabaseError::Pool(format!( + "Database backend '{}' is not available. Rebuild with the appropriate feature flag.", + config.backend + ))), } } @@ -161,7 +162,7 @@ pub async fn create_secrets_store( ))) } #[cfg(feature = "postgres")] - _ => { + crate::config::DatabaseBackend::Postgres => { let pg = postgres::PgBackend::new(config) .await .map_err(|e| DatabaseError::Pool(e.to_string()))?; @@ -172,14 +173,142 @@ pub async fn create_secrets_store( crypto, ))) } - #[cfg(not(feature = "postgres"))] - _ => Err(DatabaseError::Pool( - "No database backend available for secrets. Enable 'postgres' or 'libsql' feature." 
- .to_string(), - )), + #[allow(unreachable_patterns)] + _ => Err(DatabaseError::Pool(format!( + "Database backend '{}' is not available for secrets. Rebuild with the appropriate feature flag.", + config.backend + ))), } } +// ==================== Wizard / testing helpers ==================== + +/// Connect to the database WITHOUT running migrations, validating +/// prerequisites when applicable (PostgreSQL version, pgvector). +/// +/// Returns both the `Database` trait object and backend-specific handles. +/// Used by the wizard to test connectivity before committing — call +/// [`Database::run_migrations`] on the returned trait object when ready. +pub async fn connect_without_migrations( + config: &crate::config::DatabaseConfig, +) -> Result<(Arc, DatabaseHandles), DatabaseError> { + let mut handles = DatabaseHandles::default(); + + match config.backend { + #[cfg(feature = "libsql")] + crate::config::DatabaseBackend::LibSql => { + use secrecy::ExposeSecret as _; + + let default_path = crate::config::default_libsql_path(); + let db_path = config.libsql_path.as_deref().unwrap_or(&default_path); + + let backend = if let Some(ref url) = config.libsql_url { + let token = config.libsql_auth_token.as_ref().ok_or_else(|| { + DatabaseError::Pool( + "LIBSQL_AUTH_TOKEN required when LIBSQL_URL is set".to_string(), + ) + })?; + libsql::LibSqlBackend::new_remote_replica(db_path, url, token.expose_secret()) + .await + .map_err(|e| DatabaseError::Pool(e.to_string()))? + } else { + libsql::LibSqlBackend::new_local(db_path) + .await + .map_err(|e| DatabaseError::Pool(e.to_string()))? 
+ }; + + handles.libsql_db = Some(backend.shared_db()); + + Ok((Arc::new(backend) as Arc, handles)) + } + #[cfg(feature = "postgres")] + crate::config::DatabaseBackend::Postgres => { + let pg = postgres::PgBackend::new(config) + .await + .map_err(|e| DatabaseError::Pool(e.to_string()))?; + + handles.pg_pool = Some(pg.pool()); + + // Validate PostgreSQL prerequisites (version, pgvector) + validate_postgres(&pg.pool()).await?; + + Ok((Arc::new(pg) as Arc, handles)) + } + #[allow(unreachable_patterns)] + _ => Err(DatabaseError::Pool(format!( + "Database backend '{}' is not available. Rebuild with the appropriate feature flag.", + config.backend + ))), + } +} + +/// Validate PostgreSQL prerequisites (version >= 15, pgvector available). +/// +/// Returns `Ok(())` if all prerequisites are met, or a `DatabaseError` +/// with a user-facing message describing the issue. +#[cfg(feature = "postgres")] +async fn validate_postgres(pool: &deadpool_postgres::Pool) -> Result<(), DatabaseError> { + let client = pool + .get() + .await + .map_err(|e| DatabaseError::Pool(format!("Failed to connect: {}", e)))?; + + // Check PostgreSQL server version (need 15+ for pgvector). + let version_row = client + .query_one("SHOW server_version", &[]) + .await + .map_err(|e| DatabaseError::Query(format!("Failed to query server version: {}", e)))?; + let version_str: &str = version_row.get(0); + let major_version = version_str + .split('.') + .next() + .and_then(|v| v.parse::().ok()) + .ok_or_else(|| { + DatabaseError::Pool(format!( + "Could not parse PostgreSQL version from '{}'. \ + Expected a numeric major version (e.g., '15.2').", + version_str + )) + })?; + + const MIN_PG_MAJOR_VERSION: u32 = 15; + + if major_version < MIN_PG_MAJOR_VERSION { + return Err(DatabaseError::Pool(format!( + "PostgreSQL {} detected. 
IronClaw requires PostgreSQL {} or later \ + for pgvector support.\n\ + Upgrade: https://www.postgresql.org/download/", + version_str, MIN_PG_MAJOR_VERSION + ))); + } + + // Check if pgvector extension is available. + let pgvector_row = client + .query_opt( + "SELECT 1 FROM pg_available_extensions WHERE name = 'vector'", + &[], + ) + .await + .map_err(|e| { + DatabaseError::Query(format!("Failed to check pgvector availability: {}", e)) + })?; + + if pgvector_row.is_none() { + return Err(DatabaseError::Pool(format!( + "pgvector extension not found on your PostgreSQL server.\n\n\ + Install it:\n \ + macOS: brew install pgvector\n \ + Ubuntu: apt install postgresql-{0}-pgvector\n \ + Docker: use the pgvector/pgvector:pg{0} image\n \ + Source: https://github.com/pgvector/pgvector#installation\n\n\ + Then restart PostgreSQL and re-run: ironclaw onboard", + major_version + ))); + } + + Ok(()) +} + // ==================== Sub-traits ==================== // // Each sub-trait groups related persistence methods. 
The `Database` supertrait diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index e057e2acc1..d63ae446a0 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -10,16 +10,17 @@ use std::sync::Arc; use tokio::sync::RwLock; -use crate::channels::ChannelManager; use crate::channels::wasm::{ - RegisteredEndpoint, SharedWasmChannel, WasmChannelLoader, WasmChannelRouter, WasmChannelRuntime, + LoadedChannel, RegisteredEndpoint, SharedWasmChannel, TELEGRAM_CHANNEL_NAME, WasmChannelLoader, + WasmChannelRouter, WasmChannelRuntime, bot_username_setting_key, }; +use crate::channels::{ChannelManager, OutgoingResponse}; use crate::extensions::discovery::OnlineDiscovery; use crate::extensions::registry::ExtensionRegistry; use crate::extensions::{ ActivateResult, AuthResult, ConfigureResult, ExtensionError, ExtensionKind, ExtensionSource, InstallResult, InstalledExtension, RegistryEntry, ResultSource, SearchResult, ToolAuthState, - UpgradeOutcome, UpgradeResult, + UpgradeOutcome, UpgradeResult, VerificationChallenge, }; use crate::hooks::HookRegistry; use crate::pairing::PairingStore; @@ -56,7 +57,214 @@ struct ChannelRuntimeState { wasm_channel_owner_ids: std::collections::HashMap, } +#[cfg(test)] +type TestWasmChannelLoader = + Arc Result + Send + Sync>; +#[cfg(test)] +type TestTelegramBindingResolver = + Arc) -> Result + Send + Sync>; + +const TELEGRAM_OWNER_BIND_TIMEOUT_SECS: u64 = 120; +const TELEGRAM_OWNER_BIND_CHALLENGE_TTL_SECS: u64 = 300; +const TELEGRAM_GET_UPDATES_TIMEOUT_SECS: u64 = 25; +const TELEGRAM_OWNER_BIND_CODE_LEN: usize = 8; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct TelegramBindingData { + owner_id: i64, + bot_username: Option, + binding_state: TelegramOwnerBindingState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TelegramOwnerBindingState { + Existing, + VerifiedNow, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PendingTelegramVerificationChallenge { + code: String, + bot_username: Option, + 
expires_at_unix: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TelegramBindingResult { + Bound(TelegramBindingData), + Pending(VerificationChallenge), +} + +fn telegram_request_error(action: &'static str, error: &reqwest::Error) -> ExtensionError { + tracing::warn!( + action, + status = error.status().map(|status| status.as_u16()), + is_timeout = error.is_timeout(), + is_connect = error.is_connect(), + "Telegram API request failed" + ); + ExtensionError::Other(format!("Telegram {action} request failed")) +} + +fn telegram_response_parse_error(action: &'static str, error: &reqwest::Error) -> ExtensionError { + tracing::warn!( + action, + status = error.status().map(|status| status.as_u16()), + is_timeout = error.is_timeout(), + "Telegram API response parse failed" + ); + ExtensionError::Other(format!("Failed to parse Telegram {action} response")) +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramGetMeResponse { + ok: bool, + #[serde(default)] + result: Option, + #[serde(default)] + description: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramGetMeUser { + #[serde(default)] + username: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramGetUpdatesResponse { + ok: bool, + #[serde(default)] + result: Vec, + #[serde(default)] + description: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramUpdate { + update_id: i64, + #[serde(default)] + message: Option, + #[serde(default)] + edited_message: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramMessage { + chat: TelegramChat, + #[serde(default)] + from: Option, + #[serde(default)] + text: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramChat { + #[serde(rename = "type")] + chat_type: String, +} + +#[derive(Debug, serde::Deserialize)] +struct TelegramUser { + id: i64, + is_bot: bool, +} + +fn build_wasm_channel_runtime_config_updates( + tunnel_url: Option<&str>, + webhook_secret: Option<&str>, + owner_id: Option, +) -> 
/// True when `text` plausibly carries the verification `code`.
///
/// Accepts the bare code, the exact `/start CODE` deep-link payload, or any
/// whitespace-separated token that equals the code once surrounding
/// punctuation is stripped (ASCII alphanumerics and `-` are kept).
fn telegram_message_matches_verification_code(text: &str, code: &str) -> bool {
    let trimmed = text.trim();
    if trimmed == code || trimmed == format!("/start {code}") {
        return true;
    }

    // Tolerate decoration like "(abc123)" or "code: abc123." by stripping
    // non-alphanumeric edges from each whitespace-separated token.
    let strip_edges = |c: char| !c.is_ascii_alphanumeric() && c != '-';
    trimmed
        .split_whitespace()
        .any(|token| token.trim_matches(strip_edges) == code)
}
@@ -190,9 +403,24 @@ impl ExtensionManager { relay_config: crate::config::RelayConfig::from_env(), gateway_mode: std::sync::atomic::AtomicBool::new(false), gateway_base_url: RwLock::new(None), + pending_telegram_verification: RwLock::new(HashMap::new()), + #[cfg(test)] + test_wasm_channel_loader: RwLock::new(None), + #[cfg(test)] + test_telegram_binding_resolver: RwLock::new(None), } } + #[cfg(test)] + async fn set_test_wasm_channel_loader(&self, loader: TestWasmChannelLoader) { + *self.test_wasm_channel_loader.write().await = Some(loader); + } + + #[cfg(test)] + async fn set_test_telegram_binding_resolver(&self, resolver: TestTelegramBindingResolver) { + *self.test_telegram_binding_resolver.write().await = Some(resolver); + } + /// Enable gateway mode so OAuth flows return auth URLs to the frontend /// instead of calling `open::that()` on the server. /// @@ -298,14 +526,6 @@ impl ExtensionManager { }); } - /// Set just the channel manager for relay channel hot-activation. - /// - /// Call this when WASM channel runtime is not available but relay channels - /// still need to be hot-added. 
- pub async fn set_relay_channel_manager(&self, channel_manager: Arc) { - *self.relay_channel_manager.write().await = Some(channel_manager); - } - async fn current_channel_owner_id(&self, name: &str) -> Option { { let rt_guard = self.channel_runtime.read().await; @@ -334,6 +554,131 @@ impl ExtensionManager { } } + async fn set_channel_owner_id(&self, name: &str, owner_id: i64) -> Result<(), ExtensionError> { + if let Some(store) = self.store.as_ref() { + store + .set_setting( + &self.user_id, + &format!("channels.wasm_channel_owner_ids.{name}"), + &serde_json::json!(owner_id), + ) + .await + .map_err(|e| ExtensionError::Config(e.to_string()))?; + } + + let mut rt_guard = self.channel_runtime.write().await; + if let Some(rt) = rt_guard.as_mut() { + rt.wasm_channel_owner_ids.insert(name.to_string(), owner_id); + } + + Ok(()) + } + + async fn load_channel_runtime_config_overrides( + &self, + name: &str, + ) -> HashMap { + let mut overrides = HashMap::new(); + + if name == TELEGRAM_CHANNEL_NAME + && let Some(store) = self.store.as_ref() + && let Ok(Some(serde_json::Value::String(username))) = store + .get_setting(&self.user_id, &bot_username_setting_key(name)) + .await + && !username.trim().is_empty() + { + overrides.insert("bot_username".to_string(), serde_json::json!(username)); + } + + overrides + } + + pub async fn has_wasm_channel_owner_binding(&self, name: &str) -> bool { + self.current_channel_owner_id(name).await.is_some() + } + + async fn get_pending_telegram_verification( + &self, + name: &str, + ) -> Option { + let now = unix_timestamp_secs(); + let mut guard = self.pending_telegram_verification.write().await; + let challenge = guard.get(name).cloned()?; + if challenge.expires_at_unix <= now { + guard.remove(name); + return None; + } + Some(challenge) + } + + async fn set_pending_telegram_verification( + &self, + name: &str, + challenge: PendingTelegramVerificationChallenge, + ) { + self.pending_telegram_verification + .write() + .await + 
.insert(name.to_string(), challenge); + } + + async fn clear_pending_telegram_verification(&self, name: &str) { + self.pending_telegram_verification + .write() + .await + .remove(name); + } + + async fn issue_telegram_verification_challenge( + &self, + client: &reqwest::Client, + name: &str, + bot_token: &str, + bot_username: Option<&str>, + ) -> Result { + let delete_webhook_url = format!("https://api.telegram.org/bot{bot_token}/deleteWebhook"); + let delete_webhook_resp = client + .post(&delete_webhook_url) + .query(&[("drop_pending_updates", "true")]) + .send() + .await + .map_err(|e| telegram_request_error("deleteWebhook", &e))?; + if !delete_webhook_resp.status().is_success() { + return Err(ExtensionError::Other(format!( + "Telegram deleteWebhook failed (HTTP {})", + delete_webhook_resp.status() + ))); + } + + let challenge = PendingTelegramVerificationChallenge { + code: generate_telegram_verification_code(), + bot_username: bot_username.map(str::to_string), + expires_at_unix: unix_timestamp_secs() + TELEGRAM_OWNER_BIND_CHALLENGE_TTL_SECS, + }; + self.set_pending_telegram_verification(name, challenge.clone()) + .await; + + Ok(VerificationChallenge { + code: challenge.code.clone(), + instructions: telegram_verification_instructions( + challenge.bot_username.as_deref(), + &challenge.code, + ), + deep_link: telegram_verification_deep_link( + challenge.bot_username.as_deref(), + &challenge.code, + ), + }) + } + + /// Set just the channel manager for relay channel hot-activation. + /// + /// Call this when WASM channel runtime is not available but relay channels + /// still need to be hot-added. + pub async fn set_relay_channel_manager(&self, channel_manager: Arc) { + *self.relay_channel_manager.write().await = Some(channel_manager); + } + /// Check if a channel name corresponds to a relay extension (has stored stream token). 
pub async fn is_relay_channel(&self, name: &str) -> bool { self.secrets @@ -346,7 +691,10 @@ impl ExtensionManager { /// /// Loads the persisted active channel list, filters to relay types (those with /// a stored stream token), and activates each via `activate_stored_relay()`. - /// Skips channels that are already active. Call this after `set_relay_channel_manager()`. + /// Skips channels that are already active. + /// + /// Call this only after `set_relay_channel_manager()` or `set_channel_runtime()`. + /// Otherwise, each activation attempt fails with "Channel manager not initialized". pub async fn restore_relay_channels(&self) { let persisted = self.load_persisted_active_channels().await; let already_active = self.active_channel_names.read().await.clone(); @@ -2818,7 +3166,7 @@ impl ExtensionManager { Ok(AuthResult::awaiting_token( name, ExtensionKind::WasmChannel, - secret.prompt.clone(), + channel_auth_instructions(name, secret), cap_file.setup.setup_url.clone(), )) } @@ -3021,7 +3369,13 @@ impl ExtensionManager { // Verify runtime infrastructure is available and clone Arcs so we don't // hold the RwLock guard across awaits. 
- let (channel_runtime, channel_manager, pairing_store, wasm_channel_router) = { + let ( + channel_runtime, + channel_manager, + pairing_store, + wasm_channel_router, + wasm_channel_owner_ids, + ) = { let rt_guard = self.channel_runtime.read().await; let rt = rt_guard.as_ref().ok_or_else(|| { ExtensionError::ActivationFailed("WASM channel runtime not configured".to_string()) @@ -3031,6 +3385,7 @@ impl ExtensionManager { Arc::clone(&rt.channel_manager), Arc::clone(&rt.pairing_store), Arc::clone(&rt.wasm_channel_router), + rt.wasm_channel_owner_ids.clone(), ) }; @@ -3054,19 +3409,58 @@ impl ExtensionManager { None }; - let settings_store: Option> = - self.store.as_ref().map(|db| Arc::clone(db) as _); - let loader = WasmChannelLoader::new( - Arc::clone(&channel_runtime), - Arc::clone(&pairing_store), - settings_store, + #[cfg(test)] + let loaded = if let Some(loader) = self.test_wasm_channel_loader.read().await.as_ref() { + loader(name)? + } else { + let settings_store: Option> = + self.store.as_ref().map(|db| Arc::clone(db) as _); + let loader = WasmChannelLoader::new( + Arc::clone(&channel_runtime), + Arc::clone(&pairing_store), + settings_store, + ) + .with_secrets_store(Arc::clone(&self.secrets)); + loader + .load_from_files(name, &wasm_path, cap_path_option) + .await + .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))? + }; + + #[cfg(not(test))] + let loaded = { + let settings_store: Option> = + self.store.as_ref().map(|db| Arc::clone(db) as _); + let loader = WasmChannelLoader::new( + Arc::clone(&channel_runtime), + Arc::clone(&pairing_store), + settings_store, + ) + .with_secrets_store(Arc::clone(&self.secrets)); + loader + .load_from_files(name, &wasm_path, cap_path_option) + .await + .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))? 
+ }; + + self.complete_loaded_wasm_channel_activation( + name, + loaded, + &channel_manager, + &wasm_channel_router, + wasm_channel_owner_ids.get(name).copied(), ) - .with_secrets_store(Arc::clone(&self.secrets)); - let loaded = loader - .load_from_files(name, &wasm_path, cap_path_option) - .await - .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))?; + .await + } + async fn complete_loaded_wasm_channel_activation( + &self, + requested_name: &str, + loaded: LoadedChannel, + channel_manager: &Arc, + wasm_channel_router: &Arc, + owner_id: Option, + ) -> Result { let channel_name = loaded.name().to_string(); let webhook_secret_name = loaded.webhook_secret_name(); let secret_header = loaded.webhook_secret_header().map(|s| s.to_string()); @@ -3085,25 +3479,16 @@ impl ExtensionManager { // Inject runtime config (tunnel_url, webhook_secret, owner_id) { - let mut config_updates = std::collections::HashMap::new(); - - if let Some(ref tunnel_url) = self.tunnel_url { - config_updates.insert( - "tunnel_url".to_string(), - serde_json::Value::String(tunnel_url.clone()), - ); - } - - if let Some(ref secret) = webhook_secret { - config_updates.insert( - "webhook_secret".to_string(), - serde_json::Value::String(secret.clone()), - ); - } - - if let Some(owner_id) = self.current_channel_owner_id(&channel_name).await { - config_updates.insert("owner_id".to_string(), serde_json::json!(owner_id)); - } + let resolved_owner_id = owner_id.or(self.current_channel_owner_id(&channel_name).await); + let mut config_updates = build_wasm_channel_runtime_config_updates( + self.tunnel_url.as_deref(), + webhook_secret.as_deref(), + resolved_owner_id, + ); + config_updates.extend( + self.load_channel_runtime_config_overrides(&channel_name) + .await, + ); if !config_updates.is_empty() { channel_arc.update_config(config_updates).await; @@ -3220,7 +3605,7 @@ impl ExtensionManager { name: channel_name, kind: ExtensionKind::WasmChannel, tools_loaded: Vec::new(), - message: format!("Channel '{}' 
activated and running", name), + message: format!("Channel '{}' activated and running", requested_name), }) } @@ -3300,6 +3685,14 @@ impl ExtensionManager { .as_ref() .and_then(|f| f.hmac_secret_name().map(|s| s.to_string())); + let mut config_updates = build_wasm_channel_runtime_config_updates( + self.tunnel_url.as_deref(), + None, + self.current_channel_owner_id(name).await, + ); + config_updates.extend(self.load_channel_runtime_config_overrides(name).await); + let mut should_rerun_on_start = false; + // Refresh webhook secret if let Ok(secret) = self .secrets @@ -3309,14 +3702,11 @@ impl ExtensionManager { router .update_secret(name, secret.expose().to_string()) .await; - - // Also inject the webhook_secret into the channel's runtime config - let mut config_updates = std::collections::HashMap::new(); config_updates.insert( "webhook_secret".to_string(), serde_json::Value::String(secret.expose().to_string()), ); - existing_channel.update_config(config_updates).await; + should_rerun_on_start = true; } // Refresh signature key @@ -3356,19 +3746,14 @@ impl ExtensionManager { } } - // Refresh tunnel_url in case it wasn't set at startup - if let Some(ref tunnel_url) = self.tunnel_url { - let mut config_updates = std::collections::HashMap::new(); - config_updates.insert( - "tunnel_url".to_string(), - serde_json::Value::String(tunnel_url.clone()), - ); + if !config_updates.is_empty() { existing_channel.update_config(config_updates).await; + should_rerun_on_start = true; } // Re-call on_start() to trigger webhook registration with the // now-available credentials (e.g., setWebhook for Telegram). 
- if cred_count > 0 { + if cred_count > 0 || should_rerun_on_start { match existing_channel.call_on_start().await { Ok(_config) => { tracing::info!( @@ -3719,6 +4104,304 @@ impl ExtensionManager { } } + async fn configure_telegram_binding( + &self, + name: &str, + secrets: &std::collections::HashMap, + ) -> Result { + let explicit_token = secrets + .get("telegram_bot_token") + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()); + let bot_token = if let Some(token) = explicit_token.clone() { + token + } else { + match self + .secrets + .get_decrypted(&self.user_id, "telegram_bot_token") + .await + { + Ok(secret) => { + let token = secret.expose().trim().to_string(); + if token.is_empty() { + return Err(ExtensionError::ValidationFailed( + "Telegram bot token is required before owner verification".to_string(), + )); + } + token + } + Err(crate::secrets::SecretError::NotFound(_)) => { + return Err(ExtensionError::ValidationFailed( + "Telegram bot token is required before owner verification".to_string(), + )); + } + Err(err) => { + return Err(ExtensionError::Config(format!( + "Failed to read stored Telegram bot token: {err}" + ))); + } + } + }; + + let existing_owner_id = self.current_channel_owner_id(name).await; + let binding = self + .resolve_telegram_binding(name, &bot_token, existing_owner_id) + .await?; + + match &binding { + TelegramBindingResult::Bound(data) => { + self.set_channel_owner_id(name, data.owner_id).await?; + if let Some(username) = data.bot_username.as_deref() + && let Some(store) = self.store.as_ref() + { + store + .set_setting( + &self.user_id, + &bot_username_setting_key(name), + &serde_json::json!(username), + ) + .await + .map_err(|e| ExtensionError::Config(e.to_string()))?; + } + } + TelegramBindingResult::Pending(challenge) => { + if let Some(deep_link) = challenge.deep_link.as_deref() + && let Some(username) = deep_link + .strip_prefix("https://t.me/") + .and_then(|rest| rest.split('?').next()) + .filter(|value| 
!value.trim().is_empty()) + && let Some(store) = self.store.as_ref() + { + store + .set_setting( + &self.user_id, + &bot_username_setting_key(name), + &serde_json::json!(username), + ) + .await + .map_err(|e| ExtensionError::Config(e.to_string()))?; + } + } + } + + Ok(binding) + } + + async fn resolve_telegram_binding( + &self, + name: &str, + bot_token: &str, + existing_owner_id: Option, + ) -> Result { + #[cfg(test)] + if let Some(resolver) = self.test_telegram_binding_resolver.read().await.as_ref() { + return resolver(bot_token, existing_owner_id); + } + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| ExtensionError::Other(e.to_string()))?; + + let get_me_url = format!("https://api.telegram.org/bot{bot_token}/getMe"); + let get_me_resp = client + .get(&get_me_url) + .send() + .await + .map_err(|e| telegram_request_error("getMe", &e))?; + let get_me_status = get_me_resp.status(); + if !get_me_status.is_success() { + return Err(ExtensionError::ValidationFailed(format!( + "Telegram token validation failed (HTTP {get_me_status})" + ))); + } + + let get_me: TelegramGetMeResponse = get_me_resp + .json() + .await + .map_err(|e| telegram_response_parse_error("getMe", &e))?; + if !get_me.ok { + return Err(ExtensionError::ValidationFailed( + get_me + .description + .unwrap_or_else(|| "Telegram getMe returned ok=false".to_string()), + )); + } + + let bot_username = get_me + .result + .and_then(|result| result.username) + .filter(|username| !username.trim().is_empty()); + + if let Some(owner_id) = existing_owner_id { + self.clear_pending_telegram_verification(name).await; + return Ok(TelegramBindingResult::Bound(TelegramBindingData { + owner_id, + bot_username: bot_username.clone(), + binding_state: TelegramOwnerBindingState::Existing, + })); + } + + let pending_challenge = self.get_pending_telegram_verification(name).await; + + let challenge = if let Some(challenge) = pending_challenge { + challenge + } 
else { + return Ok(TelegramBindingResult::Pending( + self.issue_telegram_verification_challenge( + &client, + name, + bot_token, + bot_username.as_deref(), + ) + .await?, + )); + }; + + let now = unix_timestamp_secs(); + if challenge.expires_at_unix <= now { + self.clear_pending_telegram_verification(name).await; + return Ok(TelegramBindingResult::Pending( + self.issue_telegram_verification_challenge( + &client, + name, + bot_token, + bot_username.as_deref(), + ) + .await?, + )); + } + + let deadline = std::time::Instant::now() + + std::time::Duration::from_secs(TELEGRAM_OWNER_BIND_TIMEOUT_SECS); + let mut offset = 0_i64; + + while std::time::Instant::now() < deadline { + let remaining_secs = deadline + .saturating_duration_since(std::time::Instant::now()) + .as_secs() + .max(1); + let poll_timeout_secs = TELEGRAM_GET_UPDATES_TIMEOUT_SECS.min(remaining_secs); + + let resp = client + .get(format!( + "https://api.telegram.org/bot{bot_token}/getUpdates" + )) + .query(&[ + ("offset", offset.to_string()), + ("timeout", poll_timeout_secs.to_string()), + ( + "allowed_updates", + "[\"message\",\"edited_message\"]".to_string(), + ), + ]) + .send() + .await + .map_err(|e| telegram_request_error("getUpdates", &e))?; + + if !resp.status().is_success() { + return Err(ExtensionError::Other(format!( + "Telegram getUpdates failed (HTTP {})", + resp.status() + ))); + } + + let updates: TelegramGetUpdatesResponse = resp + .json() + .await + .map_err(|e| telegram_response_parse_error("getUpdates", &e))?; + + if !updates.ok { + return Err(ExtensionError::Other(updates.description.unwrap_or_else( + || "Telegram getUpdates returned ok=false".to_string(), + ))); + } + + let mut bound_owner_id = None; + for update in updates.result { + offset = offset.max(update.update_id + 1); + let message = update.message.or(update.edited_message); + if let Some(message) = message + && message.chat.chat_type == "private" + && let Some(from) = message.from + && !from.is_bot + && let Some(text) = 
message.text.as_deref() + && telegram_message_matches_verification_code(text, &challenge.code) + { + bound_owner_id = Some(from.id); + } + } + + if let Some(owner_id) = bound_owner_id { + self.clear_pending_telegram_verification(name).await; + if offset > 0 { + let _ = client + .get(format!( + "https://api.telegram.org/bot{bot_token}/getUpdates" + )) + .query(&[("offset", offset.to_string()), ("timeout", "0".to_string())]) + .send() + .await; + } + + return Ok(TelegramBindingResult::Bound(TelegramBindingData { + owner_id, + bot_username, + binding_state: TelegramOwnerBindingState::VerifiedNow, + })); + } + } + + Err(ExtensionError::ValidationFailed(format!( + "Telegram owner verification timed out. Send `/start {}` to your bot, then click Verify owner again.", + challenge.code + ))) + } + + async fn notify_telegram_owner_verified( + &self, + channel_name: &str, + binding: Option<&TelegramBindingData>, + ) { + let Some(binding) = binding else { + return; + }; + if binding.binding_state != TelegramOwnerBindingState::VerifiedNow { + return; + } + + let channel_manager = { + let rt_guard = self.channel_runtime.read().await; + rt_guard.as_ref().map(|rt| Arc::clone(&rt.channel_manager)) + }; + let Some(channel_manager) = channel_manager else { + tracing::debug!( + channel = channel_name, + owner_id = binding.owner_id, + "Skipping Telegram owner confirmation message because channel runtime is unavailable" + ); + return; + }; + + if let Err(err) = channel_manager + .broadcast( + channel_name, + &binding.owner_id.to_string(), + OutgoingResponse::text( + "Telegram owner verified. This bot is now active and ready for you.", + ), + ) + .await + { + tracing::warn!( + channel = channel_name, + owner_id = binding.owner_id, + error = %err, + "Failed to send Telegram owner verification confirmation" + ); + } + } + /// Save setup secrets for an extension, validating names against the capabilities schema. 
/// /// Configure secrets for an extension: validate, store, auto-generate, and activate. @@ -3817,9 +4500,16 @@ impl ExtensionManager { { let token = token_value.trim(); if !token.is_empty() { - let encoded = - url::form_urlencoded::byte_serialize(token.as_bytes()).collect::(); - let url = endpoint_template.replace(&format!("{{{}}}", secret_def.name), &encoded); + // Telegram tokens contain colons (numeric_id:token_part) in the URL path, + // not query parameters, so URL-encoding breaks the endpoint. + // For other extensions, keep encoding to handle special chars in query parameters. + let url = if name == "telegram" { + endpoint_template.replace(&format!("{{{}}}", secret_def.name), token) + } else { + let encoded = + url::form_urlencoded::byte_serialize(token.as_bytes()).collect::(); + endpoint_template.replace(&format!("{{{}}}", secret_def.name), &encoded) + }; // SSRF defense: block private IPs, localhost, cloud metadata endpoints crate::tools::builtin::skill_tools::validate_fetch_url(&url) .map_err(|e| ExtensionError::Other(format!("SSRF blocked: {}", e)))?; @@ -3897,6 +4587,26 @@ impl ExtensionManager { } } + let mut telegram_binding = None; + if kind == ExtensionKind::WasmChannel && name == TELEGRAM_CHANNEL_NAME { + match self.configure_telegram_binding(name, secrets).await? { + TelegramBindingResult::Bound(binding) => { + telegram_binding = Some(binding); + } + TelegramBindingResult::Pending(verification) => { + return Ok(ConfigureResult { + message: format!( + "Configuration saved for '{}'. {}", + name, verification.instructions + ), + activated: false, + auth_url: None, + verification: Some(verification), + }); + } + } + } + // For tools, save and attempt auto-activation, then check auth. 
if kind == ExtensionKind::WasmTool { match self.activate_wasm_tool(name).await { @@ -3948,6 +4658,7 @@ impl ExtensionManager { message, activated: true, auth_url, + verification: None, }); } Err(e) => { @@ -3960,6 +4671,7 @@ impl ExtensionManager { message: format!("Configuration saved for '{}'.", name), activated: false, auth_url: None, + verification: None, }); } } @@ -3977,6 +4689,7 @@ impl ExtensionManager { message: format!("Configuration saved for '{}'.", name), activated: false, auth_url: None, + verification: None, }); } }; @@ -3985,13 +4698,26 @@ impl ExtensionManager { Ok(result) => { self.activation_errors.write().await.remove(name); self.broadcast_extension_status(name, "active", None).await; - Ok(ConfigureResult { - message: format!( + if name == TELEGRAM_CHANNEL_NAME { + self.notify_telegram_owner_verified(name, telegram_binding.as_ref()) + .await; + } + let message = if name == TELEGRAM_CHANNEL_NAME { + format!( + "Configuration saved, Telegram owner verified, and '{}' activated. {}", + name, result.message + ) + } else { + format!( "Configuration saved and '{}' activated. 
{}", name, result.message - ), + ) + }; + Ok(ConfigureResult { + message, activated: true, auth_url: None, + verification: None, }) } Err(e) => { @@ -4014,6 +4740,7 @@ impl ExtensionManager { ), activated: false, auth_url: None, + verification: None, }) } } @@ -4373,13 +5100,101 @@ fn combine_install_errors( #[cfg(test)] mod tests { + use std::fmt::Debug; use std::sync::Arc; + use async_trait::async_trait; + use futures::stream; + + use crate::channels::wasm::{ + ChannelCapabilities, LoadedChannel, PreparedChannelModule, WasmChannel, WasmChannelRouter, + WasmChannelRuntime, WasmChannelRuntimeConfig, bot_username_setting_key, + }; + use crate::channels::{ + Channel, ChannelManager, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate, + }; use crate::extensions::ExtensionManager; use crate::extensions::manager::{ - FallbackDecision, combine_install_errors, fallback_decision, infer_kind_from_url, + ChannelRuntimeState, FallbackDecision, TelegramBindingData, TelegramBindingResult, + TelegramOwnerBindingState, build_wasm_channel_runtime_config_updates, + combine_install_errors, fallback_decision, infer_kind_from_url, + telegram_message_matches_verification_code, }; - use crate::extensions::{ExtensionError, ExtensionKind, ExtensionSource, InstallResult}; + use crate::extensions::{ + ExtensionError, ExtensionKind, ExtensionSource, InstallResult, VerificationChallenge, + }; + use crate::pairing::PairingStore; + + fn require(condition: bool, message: impl Into) -> Result<(), String> { + if condition { + Ok(()) + } else { + Err(message.into()) + } + } + + fn require_eq(actual: T, expected: T, label: &str) -> Result<(), String> + where + T: PartialEq + Debug, + { + if actual == expected { + Ok(()) + } else { + Err(format!( + "{label} mismatch: expected {:?}, got {:?}", + expected, actual + )) + } + } + + #[derive(Clone)] + struct RecordingChannel { + name: String, + broadcasts: Arc>>, + } + + #[async_trait] + impl Channel for RecordingChannel { + fn name(&self) -> 
&str { + &self.name + } + + async fn start(&self) -> Result { + Ok(Box::pin(stream::empty())) + } + + async fn respond( + &self, + _msg: &IncomingMessage, + _response: OutgoingResponse, + ) -> Result<(), crate::error::ChannelError> { + Ok(()) + } + + async fn send_status( + &self, + _status: StatusUpdate, + _metadata: &serde_json::Value, + ) -> Result<(), crate::error::ChannelError> { + Ok(()) + } + + async fn broadcast( + &self, + user_id: &str, + response: OutgoingResponse, + ) -> Result<(), crate::error::ChannelError> { + self.broadcasts + .lock() + .await + .push((user_id.to_string(), response)); + Ok(()) + } + + async fn health_check(&self) -> Result<(), crate::error::ChannelError> { + Ok(()) + } + } #[test] fn test_infer_kind_from_url() { @@ -4762,7 +5577,10 @@ mod tests { std::fs::create_dir_all(&channels_dir).ok(); let master_key = secrecy::SecretString::from(TEST_CRYPTO_KEY.to_string()); - let crypto = Arc::new(SecretsCrypto::new(master_key).unwrap()); + let crypto = Arc::new( + SecretsCrypto::new(master_key) + .unwrap_or_else(|err| panic!("failed to construct test crypto: {err}")), + ); ExtensionManager::new( Arc::new(McpSessionManager::new()), @@ -4780,6 +5598,56 @@ mod tests { ) } + fn make_test_loaded_channel( + runtime: Arc, + name: &str, + pairing_store: Arc, + ) -> LoadedChannel { + let prepared = Arc::new(PreparedChannelModule::for_testing( + name, + format!("Mock channel: {}", name), + )); + let capabilities = + ChannelCapabilities::for_channel(name).with_path(format!("/webhook/{}", name)); + + LoadedChannel { + channel: WasmChannel::new( + runtime, + prepared, + capabilities, + "{}".to_string(), + pairing_store, + None, + ), + capabilities_file: None, + } + } + + #[test] + fn test_telegram_hot_activation_runtime_config_includes_owner_id() -> Result<(), String> { + let updates = build_wasm_channel_runtime_config_updates( + Some("https://example.test"), + Some("secret-123"), + Some(424242), + ); + + require_eq( + updates.get("tunnel_url"), + 
Some(&serde_json::json!("https://example.test")), + "tunnel_url", + )?; + require_eq( + updates.get("webhook_secret"), + Some(&serde_json::json!("secret-123")), + "webhook_secret", + )?; + require_eq( + updates.get("owner_id"), + Some(&serde_json::json!(424242)), + "owner_id", + ) + } + #[tokio::test] async fn test_current_channel_owner_id_uses_runtime_state() -> Result<(), String> { let manager = make_manager_with_temp_dirs(); @@ -4813,6 +5681,280 @@ mod tests { Ok(()) } + #[cfg(feature = "libsql")] + #[tokio::test] + async fn test_telegram_hot_activation_configure_uses_mock_loader_and_persists_state() + -> Result<(), String> { + let dir = tempfile::tempdir().map_err(|err| format!("temp dir: {err}"))?; + let channels_dir = dir.path().join("channels"); + std::fs::create_dir_all(&channels_dir).map_err(|err| format!("channels dir: {err}"))?; + std::fs::write(channels_dir.join("telegram.wasm"), b"mock") + .map_err(|err| format!("write wasm: {err}"))?; + std::fs::write( + channels_dir.join("telegram.capabilities.json"), + serde_json::to_vec(&serde_json::json!({ + "type": "channel", + "name": "telegram", + "setup": { + "required_secrets": [ + { + "name": "telegram_bot_token", + "prompt": "Enter your Telegram Bot API token (from @BotFather)", + "optional": false + } + ] + }, + "capabilities": { + "channel": { + "allowed_paths": ["/webhook/telegram"] + } + }, + "config": { + "owner_id": null + } + })) + .map_err(|err| format!("serialize capabilities: {err}"))?, + ) + .map_err(|err| format!("write capabilities: {err}"))?; + + let (db, _db_tmp) = crate::testing::test_db().await; + let manager = { + use crate::secrets::{InMemorySecretsStore, SecretsCrypto}; + use crate::testing::credentials::TEST_CRYPTO_KEY; + use crate::tools::ToolRegistry; + use crate::tools::mcp::process::McpProcessManager; + use crate::tools::mcp::session::McpSessionManager; + + let master_key = secrecy::SecretString::from(TEST_CRYPTO_KEY.to_string()); + let crypto = Arc::new( + 
SecretsCrypto::new(master_key) + .unwrap_or_else(|err| panic!("failed to construct test crypto: {err}")), + ); + + ExtensionManager::new( + Arc::new(McpSessionManager::new()), + Arc::new(McpProcessManager::new()), + Arc::new(InMemorySecretsStore::new(crypto)), + Arc::new(ToolRegistry::new()), + None, + None, + dir.path().join("tools"), + channels_dir.clone(), + None, + "test".to_string(), + Some(db), + Vec::new(), + ) + }; + + let channel_manager = Arc::new(ChannelManager::new()); + let runtime = Arc::new( + WasmChannelRuntime::new(WasmChannelRuntimeConfig::for_testing()) + .map_err(|err| format!("runtime: {err}"))?, + ); + let pairing_store = Arc::new(PairingStore::with_base_dir( + dir.path().join("pairing-state"), + )); + let router = Arc::new(WasmChannelRouter::new()); + manager + .set_channel_runtime( + Arc::clone(&channel_manager), + Arc::clone(&runtime), + Arc::clone(&pairing_store), + Arc::clone(&router), + std::collections::HashMap::new(), + ) + .await; + manager + .set_test_wasm_channel_loader(Arc::new({ + let runtime = Arc::clone(&runtime); + let pairing_store = Arc::clone(&pairing_store); + move |name| { + Ok(make_test_loaded_channel( + Arc::clone(&runtime), + name, + Arc::clone(&pairing_store), + )) + } + })) + .await; + manager + .set_test_telegram_binding_resolver(Arc::new(|_token, existing_owner_id| { + if existing_owner_id.is_some() { + return Err(ExtensionError::Other( + "owner binding should be derived during setup".to_string(), + )); + } + Ok(TelegramBindingResult::Bound(TelegramBindingData { + owner_id: 424242, + bot_username: Some("test_hot_bot".to_string()), + binding_state: TelegramOwnerBindingState::VerifiedNow, + })) + })) + .await; + + manager + .activation_errors + .write() + .await + .insert("telegram".to_string(), "stale failure".to_string()); + + let result = manager + .configure( + "telegram", + &std::collections::HashMap::from([( + "telegram_bot_token".to_string(), + "123456789:ABCdefGhI".to_string(), + )]), + ) + .await + 
.map_err(|err| format!("configure succeeds: {err}"))?; + + require(result.activated, "expected hot activation to succeed")?; + require( + result.message.contains("activated"), + format!("unexpected message: {}", result.message), + )?; + require( + !manager + .activation_errors + .read() + .await + .contains_key("telegram"), + "successful configure should clear stale activation errors", + )?; + require( + manager + .active_channel_names + .read() + .await + .contains("telegram"), + "telegram should be marked active after hot activation", + )?; + require( + channel_manager.get_channel("telegram").await.is_some(), + "telegram should be hot-added to the running channel manager", + )?; + require_eq( + manager.load_persisted_active_channels().await, + vec!["telegram".to_string()], + "persisted active channels", + )?; + require_eq( + manager.current_channel_owner_id("telegram").await, + Some(424242), + "current owner id", + )?; + require( + manager.has_wasm_channel_owner_binding("telegram").await, + "telegram should report an explicit owner binding after setup".to_string(), + )?; + let owner_setting = manager + .store + .as_ref() + .ok_or_else(|| "db-backed manager missing".to_string())? + .get_setting("test", "channels.wasm_channel_owner_ids.telegram") + .await + .map_err(|err| format!("owner_id setting query: {err}"))?; + require_eq( + owner_setting, + Some(serde_json::json!(424242)), + "owner setting", + )?; + let bot_username_setting = manager + .store + .as_ref() + .ok_or_else(|| "db-backed manager missing".to_string())? 
+ .get_setting("test", &bot_username_setting_key("telegram")) + .await + .map_err(|err| format!("bot username setting query: {err}"))?; + require_eq( + bot_username_setting, + Some(serde_json::json!("test_hot_bot")), + "bot username setting", + ) + } + + #[tokio::test] + async fn test_telegram_hot_activation_returns_verification_challenge_before_binding() + -> Result<(), String> { + let dir = tempfile::tempdir().map_err(|err| format!("temp dir: {err}"))?; + let channels_dir = dir.path().join("channels"); + std::fs::create_dir_all(&channels_dir).map_err(|err| format!("channels dir: {err}"))?; + std::fs::write(channels_dir.join("telegram.wasm"), b"mock") + .map_err(|err| format!("write wasm: {err}"))?; + std::fs::write( + channels_dir.join("telegram.capabilities.json"), + serde_json::to_vec(&serde_json::json!({ + "type": "channel", + "name": "telegram", + "setup": { + "required_secrets": [ + { + "name": "telegram_bot_token", + "prompt": "Enter your Telegram Bot API token (from @BotFather)", + "optional": false + } + ] + }, + "capabilities": { + "channel": { + "allowed_paths": ["/webhook/telegram"] + } + } + })) + .map_err(|err| format!("serialize capabilities: {err}"))?, + ) + .map_err(|err| format!("write capabilities: {err}"))?; + + let manager = + make_manager_custom_dirs(dir.path().join("tools"), dir.path().join("channels")); + manager + .set_test_telegram_binding_resolver(Arc::new(|_token, existing_owner_id| { + if existing_owner_id.is_some() { + return Err(ExtensionError::Other( + "owner binding should not exist before verification".to_string(), + )); + } + Ok(TelegramBindingResult::Pending(VerificationChallenge { + code: "iclaw-7qk2m9".to_string(), + instructions: + "Send `/start iclaw-7qk2m9` to @test_hot_bot, then click Verify owner." 
+ .to_string(), + deep_link: Some("https://t.me/test_hot_bot?start=iclaw-7qk2m9".to_string()), + })) + })) + .await; + + let result = manager + .configure( + "telegram", + &std::collections::HashMap::from([( + "telegram_bot_token".to_string(), + "123456789:ABCdefGhI".to_string(), + )]), + ) + .await + .map_err(|err| format!("configure returned challenge: {err}"))?; + + require( + !result.activated, + "expected setup to pause for verification", + )?; + require( + result.verification.as_ref().map(|v| v.code.as_str()) == Some("iclaw-7qk2m9"), + "expected verification code in configure result", + )?; + require( + !manager + .active_channel_names + .read() + .await + .contains("telegram"), + "telegram should not activate until owner verification completes", + ) + } + #[cfg(feature = "libsql")] #[tokio::test] async fn test_current_channel_owner_id_uses_store_fallback() -> Result<(), String> { @@ -4900,6 +6042,104 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_notify_telegram_owner_verified_sends_confirmation_for_new_binding() + -> Result<(), String> { + let dir = tempfile::tempdir().map_err(|err| format!("temp dir: {err}"))?; + let manager = + make_manager_custom_dirs(dir.path().join("tools"), dir.path().join("channels")); + + let channel_manager = Arc::new(ChannelManager::new()); + let broadcasts = Arc::new(tokio::sync::Mutex::new(Vec::new())); + channel_manager + .add(Box::new(RecordingChannel { + name: "telegram".to_string(), + broadcasts: Arc::clone(&broadcasts), + })) + .await; + + manager + .channel_runtime + .write() + .await + .replace(ChannelRuntimeState { + channel_manager, + wasm_channel_runtime: Arc::new( + WasmChannelRuntime::new(WasmChannelRuntimeConfig::for_testing()) + .map_err(|err| format!("runtime: {err}"))?, + ), + pairing_store: Arc::new(PairingStore::with_base_dir(dir.path().join("pairing"))), + wasm_channel_router: Arc::new(WasmChannelRouter::new()), + wasm_channel_owner_ids: std::collections::HashMap::new(), + }); + + manager + 
.notify_telegram_owner_verified( + "telegram", + Some(&TelegramBindingData { + owner_id: 424242, + bot_username: Some("test_hot_bot".to_string()), + binding_state: TelegramOwnerBindingState::VerifiedNow, + }), + ) + .await; + + let sent = broadcasts.lock().await; + require_eq(sent.len(), 1, "broadcast count")?; + require_eq(sent[0].0.clone(), "424242".to_string(), "broadcast user_id")?; + require( + sent[0].1.content.contains("Telegram owner verified"), + "confirmation DM should acknowledge owner verification", + ) + } + + #[tokio::test] + async fn test_notify_telegram_owner_verified_skips_existing_binding() -> Result<(), String> { + let dir = tempfile::tempdir().map_err(|err| format!("temp dir: {err}"))?; + let manager = + make_manager_custom_dirs(dir.path().join("tools"), dir.path().join("channels")); + + let channel_manager = Arc::new(ChannelManager::new()); + let broadcasts = Arc::new(tokio::sync::Mutex::new(Vec::new())); + channel_manager + .add(Box::new(RecordingChannel { + name: "telegram".to_string(), + broadcasts: Arc::clone(&broadcasts), + })) + .await; + + manager + .channel_runtime + .write() + .await + .replace(ChannelRuntimeState { + channel_manager, + wasm_channel_runtime: Arc::new( + WasmChannelRuntime::new(WasmChannelRuntimeConfig::for_testing()) + .map_err(|err| format!("runtime: {err}"))?, + ), + pairing_store: Arc::new(PairingStore::with_base_dir(dir.path().join("pairing"))), + wasm_channel_router: Arc::new(WasmChannelRouter::new()), + wasm_channel_owner_ids: std::collections::HashMap::new(), + }); + + manager + .notify_telegram_owner_verified( + "telegram", + Some(&TelegramBindingData { + owner_id: 424242, + bot_username: Some("test_hot_bot".to_string()), + binding_state: TelegramOwnerBindingState::Existing, + }), + ) + .await; + + require( + broadcasts.lock().await.is_empty(), + "existing owner bindings should not trigger another confirmation DM", + ) + } + // ── resolve_env_credentials tests ──────────────────────────────────── #[test] @@ 
-5588,6 +6828,77 @@ mod tests { ); } + #[tokio::test] + async fn test_telegram_auth_instructions_include_owner_verification_guidance() + -> Result<(), String> { + let dir = tempfile::tempdir().map_err(|err| format!("temp dir: {err}"))?; + let channels_dir = dir.path().join("channels"); + std::fs::create_dir_all(&channels_dir).map_err(|err| format!("channels dir: {err}"))?; + + std::fs::write(channels_dir.join("telegram.wasm"), b"\0asm fake") + .map_err(|err| format!("write wasm: {err}"))?; + let caps = serde_json::json!({ + "type": "channel", + "name": "telegram", + "setup": { + "required_secrets": [ + { + "name": "telegram_bot_token", + "prompt": "Enter your Telegram Bot API token (from @BotFather)" + } + ] + } + }); + std::fs::write( + channels_dir.join("telegram.capabilities.json"), + serde_json::to_string(&caps).map_err(|err| format!("serialize caps: {err}"))?, + ) + .map_err(|err| format!("write caps: {err}"))?; + + let mgr = make_manager_custom_dirs(dir.path().join("tools"), channels_dir); + + let result = mgr + .auth("telegram") + .await + .map_err(|err| format!("telegram auth status: {err}"))?; + let instructions = result + .instructions() + .ok_or_else(|| "awaiting token instructions missing".to_string())?; + + require( + instructions.contains("Telegram Bot API token"), + "telegram auth instructions should still ask for the bot token", + )?; + require( + instructions.contains("one-time verification code") + && instructions.contains("/start CODE"), + "telegram auth instructions should explain the owner verification step", + ) + } + + #[test] + fn test_telegram_message_matches_verification_code_variants() -> Result<(), String> { + require( + telegram_message_matches_verification_code("iclaw-7qk2m9", "iclaw-7qk2m9"), + "plain verification code should match", + )?; + require( + telegram_message_matches_verification_code("/start iclaw-7qk2m9", "iclaw-7qk2m9"), + "/start payload should match", + )?; + require( + telegram_message_matches_verification_code( + "Hi! 
My code is: iclaw-7qk2m9", + "iclaw-7qk2m9", + ), + "conversational message containing the code should match", + )?; + require( + !telegram_message_matches_verification_code("/start something-else", "iclaw-7qk2m9"), + "wrong verification code should not match", + ) + } + #[tokio::test] async fn test_configure_dispatches_activation_by_kind() { // Regression: configure() must dispatch to the correct activation method @@ -5668,4 +6979,34 @@ mod tests { "Display should contain 'validation failed', got: {msg}" ); } + + #[test] + fn test_telegram_token_colon_preserved_in_validation_url() { + // Regression: Telegram tokens (format: numeric_id:alphanumeric_string) must NOT + // have their colon URL-encoded to %3A, as this breaks the validation endpoint. + // Previously: form_urlencoded::byte_serialize encoded the token, causing 404s. + // Fixed by removing URL-encoding and using the token directly. + let endpoint_template = "https://api.telegram.org/bot{telegram_bot_token}/getMe"; + let secret_name = "telegram_bot_token"; + let token = "123456789:AABBccDDeeFFgg_Test-Token"; + + // Simulate the fixed validation URL building logic + let url = endpoint_template.replace(&format!("{{{}}}", secret_name), token); + + // Verify colon is preserved + let expected = "https://api.telegram.org/bot123456789:AABBccDDeeFFgg_Test-Token/getMe"; + if url != expected { + panic!("URL mismatch: expected {expected}, got {url}"); // safety: test assertion + } + + // Verify it does NOT contain the broken percent-encoded version + if url.contains("%3A") { + panic!("URL contains URL-encoded colon (%3A): {url}"); // safety: test assertion + } + + // Verify the URL contains the original colon + if !url.contains("123456789:AABBccDDeeFFgg_Test-Token") { + panic!("URL missing token: {url}"); // safety: test assertion + } + } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 428d9b42c5..2a4d189f8e 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -453,6 +453,17 @@ pub 
struct ActivateResult { /// /// Returned by `ExtensionManager::configure()`, the single entrypoint /// for providing secrets to any extension (chat auth, gateway setup, etc.). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct VerificationChallenge { + /// One-time code the user must send back to the integration. + pub code: String, + /// Human-readable instructions for completing verification. + pub instructions: String, + /// Deep-link or shortcut URL that prefills the verification payload when supported. + #[serde(skip_serializing_if = "Option::is_none")] + pub deep_link: Option, +} + #[derive(Debug, Clone)] pub struct ConfigureResult { /// Human-readable status message. @@ -461,6 +472,8 @@ pub struct ConfigureResult { pub activated: bool, /// OAuth authorization URL (if OAuth flow was started). pub auth_url: Option, + /// Pending manual verification challenge (for Telegram owner binding, etc.). + pub verification: Option, } fn default_true() -> bool { diff --git a/src/llm/CLAUDE.md b/src/llm/CLAUDE.md index d1b9eea256..38d6901058 100644 --- a/src/llm/CLAUDE.md +++ b/src/llm/CLAUDE.md @@ -7,8 +7,12 @@ Multi-provider LLM integration with circuit breaker, retry, failover, and respon | File | Role | |------|------| | `mod.rs` | Provider factory (`create_llm_provider`, `build_provider_chain`); `LlmBackend` enum | +| `config.rs` | LLM config types (`LlmConfig`, `RegistryProviderConfig`, `NearAiConfig`, `BedrockConfig`) | +| `error.rs` | `LlmError` enum used by all providers | | `provider.rs` | `LlmProvider` trait, `ChatMessage`, `ToolCall`, `CompletionRequest`, `sanitize_tool_messages` | | `nearai_chat.rs` | NEAR AI Chat Completions provider (dual auth: session token or API key) | +| `codex_auth.rs` | Reads Codex CLI `auth.json`, extracts tokens, refreshes ChatGPT OAuth access tokens | +| `codex_chatgpt.rs` | Custom Responses API provider for Codex ChatGPT backend (`/backend-api/codex`) | | `reasoning.rs` | `Reasoning` struct, 
`ReasoningContext`, `RespondResult`, `ActionPlan`, `ToolSelection`; thinking-tag stripping; `SILENT_REPLY_TOKEN` | | `session.rs` | NEAR AI session token management with disk + DB persistence, OAuth login flow | | `circuit_breaker.rs` | Circuit breaker: Closed → Open → HalfOpen state machine | @@ -35,6 +39,12 @@ Set via `LLM_BACKEND` env var: | `tinfoil` | Tinfoil TEE inference | `TINFOIL_API_KEY`, `TINFOIL_MODEL` | | `bedrock` | AWS Bedrock (requires `--features bedrock`) | `BEDROCK_REGION`, `BEDROCK_MODEL`, `AWS_PROFILE` | +Codex auth reuse: +- Set `LLM_USE_CODEX_AUTH=true` to load credentials from `~/.codex/auth.json` (override with `CODEX_AUTH_PATH`). +- If Codex is logged in with API-key mode, IronClaw uses the standard OpenAI endpoint. +- If Codex is logged in with ChatGPT OAuth mode, IronClaw routes to the private `chatgpt.com/backend-api/codex` Responses API via `codex_chatgpt.rs`. +- ChatGPT mode supports one automatic 401 refresh using the refresh token persisted in `auth.json`. + ## AWS Bedrock Provider Uses the native Converse API via `aws-sdk-bedrockruntime` (`bedrock.rs`). Requires `--features bedrock` at build time — not in default features due to heavy AWS SDK dependencies. diff --git a/src/llm/anthropic_oauth.rs b/src/llm/anthropic_oauth.rs index 12ca223ca9..12c527f1a6 100644 --- a/src/llm/anthropic_oauth.rs +++ b/src/llm/anthropic_oauth.rs @@ -34,7 +34,9 @@ const DEFAULT_MAX_TOKENS: u32 = 8192; /// Anthropic provider using OAuth Bearer authentication. pub struct AnthropicOAuthProvider { client: Client, - token: SecretString, + /// OAuth token, wrapped in RwLock so it can be updated after a successful + /// Keychain refresh (fixes #1136: stale token reuse after expiry). 
+ token: std::sync::RwLock, model: String, base_url: Option, active_model: std::sync::RwLock, @@ -71,7 +73,7 @@ impl AnthropicOAuthProvider { Ok(Self { client, - token, + token: std::sync::RwLock::new(token), model: config.model.clone(), base_url, active_model, @@ -98,6 +100,22 @@ impl AnthropicOAuthProvider { } } + /// Read the current token from the RwLock. + fn current_token(&self) -> String { + match self.token.read() { + Ok(guard) => guard.expose_secret().to_string(), + Err(poisoned) => poisoned.into_inner().expose_secret().to_string(), + } + } + + /// Update the stored token after a successful Keychain refresh. + fn update_token(&self, new_token: SecretString) { + match self.token.write() { + Ok(mut guard) => *guard = new_token, + Err(poisoned) => *poisoned.into_inner() = new_token, + } + } + async fn send_request Deserialize<'de>>( &self, body: &AnthropicRequest, @@ -109,7 +127,7 @@ impl AnthropicOAuthProvider { let response = self .client .post(&url) - .bearer_auth(self.token.expose_secret()) + .bearer_auth(self.current_token()) .header("anthropic-version", ANTHROPIC_API_VERSION) .header("anthropic-beta", ANTHROPIC_OAUTH_BETA) .header("Content-Type", "application/json") @@ -141,6 +159,11 @@ impl AnthropicOAuthProvider { // OAuth tokens from `claude login` expire in ~8-12h. Attempt // to re-extract a fresh token from the OS credential store // (macOS Keychain / Linux credentials file) before giving up. + // + // Brief delay to give Claude Code time to complete its async + // Keychain refresh write (fixes race in #1136). + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + if let Some(fresh) = crate::config::ClaudeCodeConfig::extract_oauth_token() { let fresh_token = SecretString::from(fresh); // Retry once with the refreshed token @@ -159,6 +182,11 @@ impl AnthropicOAuthProvider { reason: e.to_string(), })?; if retry.status().is_success() { + // Persist the refreshed token so subsequent requests + // don't hit 401 again (fixes #1136). 
+ self.update_token(fresh_token); + tracing::info!("Anthropic OAuth token refreshed from credential store"); + let text = retry.text().await.map_err(|e| LlmError::RequestFailed { provider: "anthropic_oauth".to_string(), reason: format!("Failed to read response body: {}", e), @@ -659,4 +687,22 @@ mod tests { assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].name, "search"); } + + /// Regression test for #1136: token field must be mutable via RwLock + /// so that a refreshed token persists across subsequent requests. + #[test] + fn test_token_update_persists() { + let original = SecretString::from("old_token".to_string()); + let token = std::sync::RwLock::new(original); + + // Read the original + assert_eq!(token.read().unwrap().expose_secret(), "old_token"); + + // Simulate a successful refresh + let refreshed = SecretString::from("new_token".to_string()); + *token.write().unwrap() = refreshed; + + // Subsequent reads see the updated token + assert_eq!(token.read().unwrap().expose_secret(), "new_token"); + } } diff --git a/src/llm/codex_auth.rs b/src/llm/codex_auth.rs new file mode 100644 index 0000000000..6f302436c5 --- /dev/null +++ b/src/llm/codex_auth.rs @@ -0,0 +1,377 @@ +//! Read Codex CLI credentials for LLM authentication. +//! +//! When `LLM_USE_CODEX_AUTH=true`, IronClaw reads the Codex CLI's +//! `auth.json` file (default: `~/.codex/auth.json`) and extracts +//! credentials. This lets IronClaw piggyback on a Codex login without +//! implementing its own OAuth flow. +//! +//! Codex supports two auth modes: +//! - **API key** (`auth_mode: "apiKey"`) → uses `OPENAI_API_KEY` field +//! against `api.openai.com/v1`. +//! - **ChatGPT** (`auth_mode: "chatgpt"`) → uses `tokens.access_token` +//! (OAuth JWT) against `chatgpt.com/backend-api/codex`. +//! +//! When in ChatGPT mode, the provider supports automatic token refresh +//! on 401 responses using the `refresh_token` from `auth.json`. 
+ +use std::path::{Path, PathBuf}; + +use secrecy::{ExposeSecret, SecretString}; +use serde::{Deserialize, Serialize}; + +/// ChatGPT backend API endpoint used by Codex in ChatGPT auth mode. +const CHATGPT_BACKEND_URL: &str = "https://chatgpt.com/backend-api/codex"; + +/// Standard OpenAI API endpoint used by Codex in API key mode. +const OPENAI_API_URL: &str = "https://api.openai.com/v1"; + +/// OAuth token refresh endpoint (same as Codex CLI). +const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; + +/// OAuth client ID used for token refresh (same as Codex CLI). +const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; + +/// Credentials extracted from Codex's `auth.json`. +#[derive(Debug, Clone)] +pub struct CodexCredentials { + /// The bearer token (API key or ChatGPT access_token). + pub token: SecretString, + /// Whether this is a ChatGPT OAuth token (vs. an OpenAI API key). + pub is_chatgpt_mode: bool, + /// OAuth refresh token (only present in ChatGPT mode). + pub refresh_token: Option, + /// Path to the auth.json file (for persisting refreshed tokens). + pub auth_path: Option, +} + +impl CodexCredentials { + /// Returns the correct base URL for the auth mode. + /// + /// - ChatGPT mode → `https://chatgpt.com/backend-api/codex` + /// - API key mode → `https://api.openai.com/v1` + pub fn base_url(&self) -> &'static str { + if self.is_chatgpt_mode { + CHATGPT_BACKEND_URL + } else { + OPENAI_API_URL + } + } +} + +/// Partial representation of Codex's `$CODEX_HOME/auth.json`. +#[derive(Debug, Deserialize)] +struct CodexAuthJson { + auth_mode: Option, + #[serde(rename = "OPENAI_API_KEY")] + openai_api_key: Option, + tokens: Option, +} + +#[derive(Debug, Deserialize)] +struct CodexTokens { + access_token: SecretString, + refresh_token: Option, +} + +/// Request body for OAuth token refresh. 
+#[derive(Serialize)] +struct RefreshRequest<'a> { + client_id: &'a str, + grant_type: &'a str, + refresh_token: &'a str, +} + +/// Response from the OAuth token refresh endpoint. +#[derive(Debug, Deserialize)] +struct RefreshResponse { + access_token: SecretString, + refresh_token: Option, +} + +/// Default path used by Codex CLI: `~/.codex/auth.json`. +pub fn default_codex_auth_path() -> PathBuf { + let home_dir = dirs::home_dir().unwrap_or_else(|| { + tracing::warn!( + "Could not determine home directory; falling back to current working directory for Codex auth.json path" + ); + PathBuf::from(".") + }); + + home_dir.join(".codex").join("auth.json") +} + +/// Load credentials from a Codex `auth.json` file. +/// +/// Returns `None` if the file is missing, unreadable, or contains +/// no usable credentials. +pub fn load_codex_credentials(path: &Path) -> Option { + let content = match std::fs::read_to_string(path) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Could not read Codex auth file {}: {}", path.display(), e); + return None; + } + }; + + let auth: CodexAuthJson = match serde_json::from_str(&content) { + Ok(a) => a, + Err(e) => { + tracing::warn!("Failed to parse Codex auth file {}: {}", path.display(), e); + return None; + } + }; + + let is_chatgpt = auth + .auth_mode + .as_deref() + .map(|m| m == "chatgpt" || m == "chatgptAuthTokens") + .unwrap_or(false); + + // API key mode: use OPENAI_API_KEY field. + if !is_chatgpt { + if let Some(key) = auth.openai_api_key.filter(|k| !k.is_empty()) { + tracing::info!("Loaded API key from Codex auth.json (API key mode)"); + return Some(CodexCredentials { + token: SecretString::from(key), + is_chatgpt_mode: false, + refresh_token: None, + auth_path: None, + }); + } + // If auth_mode was explicitly `apiKey`, do not fall back to checking for a token. + if auth.auth_mode.is_some() { + return None; + } + } + + // ChatGPT mode: use access_token as bearer token. 
+ if let Some(tokens) = auth.tokens + && !tokens.access_token.expose_secret().is_empty() + { + tracing::info!( + "Loaded access token from Codex auth.json (ChatGPT mode, base_url={})", + CHATGPT_BACKEND_URL + ); + return Some(CodexCredentials { + token: tokens.access_token, + is_chatgpt_mode: true, + refresh_token: tokens.refresh_token, + auth_path: Some(path.to_path_buf()), + }); + } + + tracing::debug!( + "Codex auth.json at {} contains no usable credentials", + path.display() + ); + None +} + +/// Attempt to refresh an expired access token using the refresh token. +/// +/// On success, returns the new `access_token` and persists the refreshed +/// tokens back to `auth.json`. This follows the same OAuth protocol as +/// Codex CLI (`POST https://auth.openai.com/oauth/token`). +/// +/// Returns `None` if the refresh token is missing, the request fails, +/// or the response is malformed. +pub async fn refresh_access_token( + client: &reqwest::Client, + refresh_token: &SecretString, + auth_path: Option<&Path>, +) -> Option { + let req = RefreshRequest { + client_id: CLIENT_ID, + grant_type: "refresh_token", + refresh_token: refresh_token.expose_secret(), + }; + + tracing::info!("Attempting to refresh Codex OAuth access token"); + + let resp = match client + .post(REFRESH_TOKEN_URL) + .header("Content-Type", "application/json") + .json(&req) + .timeout(std::time::Duration::from_secs(10)) + .send() + .await + { + Ok(r) => r, + Err(e) => { + tracing::warn!("Token refresh request failed: {e}"); + return None; + } + }; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::warn!("Token refresh failed: HTTP {status}: {body}"); + if status.as_u16() == 401 { + tracing::warn!( + "Refresh token may be expired or revoked. 
\ + Please re-authenticate with: codex --login" + ); + } + return None; + } + + let refresh_resp: RefreshResponse = match resp.json().await { + Ok(r) => r, + Err(e) => { + tracing::warn!("Failed to parse token refresh response: {e}"); + return None; + } + }; + + let new_access_token = refresh_resp.access_token.clone(); + + // Persist refreshed tokens back to auth.json + if let Some(path) = auth_path { + if let Err(e) = persist_refreshed_tokens( + path, + refresh_resp.access_token.expose_secret(), + refresh_resp + .refresh_token + .as_ref() + .map(ExposeSecret::expose_secret), + ) { + tracing::warn!( + "Failed to persist refreshed tokens to {}: {e}", + path.display() + ); + } else { + tracing::info!("Refreshed tokens persisted to {}", path.display()); + } + } + + Some(new_access_token) +} + +/// Update `auth.json` with refreshed tokens, preserving other fields. +fn persist_refreshed_tokens( + path: &Path, + new_access_token: &str, + new_refresh_token: Option<&str>, +) -> Result<(), Box> { + let content = std::fs::read_to_string(path)?; + let mut json: serde_json::Value = serde_json::from_str(&content)?; + + if let Some(tokens) = json.get_mut("tokens") { + tokens["access_token"] = serde_json::Value::String(new_access_token.to_string()); + if let Some(rt) = new_refresh_token { + tokens["refresh_token"] = serde_json::Value::String(rt.to_string()); + } + } + + let updated = serde_json::to_string_pretty(&json)?; + let tmp_path = path.with_extension("json.tmp"); + std::fs::write(&tmp_path, updated)?; + if let Err(e) = std::fs::rename(&tmp_path, path) { + let _ = std::fs::remove_file(&tmp_path); + return Err(Box::new(e)); + } + set_auth_file_permissions(path)?; + Ok(()) +} + +#[cfg(unix)] +fn set_auth_file_permissions(path: &Path) -> Result<(), Box> { + use std::os::unix::fs::PermissionsExt; + + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + Ok(()) +} + +#[cfg(not(unix))] +fn set_auth_file_permissions(_path: &Path) -> Result<(), Box> { + Ok(()) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn loads_api_key_mode() { + let mut f = NamedTempFile::new().unwrap(); + writeln!( + f, + r#"{{"auth_mode":"apiKey","OPENAI_API_KEY":"sk-test-123"}}"# + ) + .unwrap(); + let creds = load_codex_credentials(f.path()).expect("should load"); + assert_eq!(creds.token.expose_secret(), "sk-test-123"); + assert!(!creds.is_chatgpt_mode); + assert_eq!(creds.base_url(), OPENAI_API_URL); + } + + #[test] + fn loads_chatgpt_mode() { + let mut f = NamedTempFile::new().unwrap(); + writeln!( + f, + r#"{{"auth_mode":"chatgpt","tokens":{{"id_token":{{}},"access_token":"eyJ-test","refresh_token":"rt-x"}}}}"# + ) + .unwrap(); + let creds = load_codex_credentials(f.path()).expect("should load"); + assert_eq!(creds.token.expose_secret(), "eyJ-test"); + assert!(creds.is_chatgpt_mode); + assert_eq!( + creds + .refresh_token + .as_ref() + .expect("refresh token should be present") + .expose_secret(), + "rt-x" + ); + assert_eq!(creds.base_url(), CHATGPT_BACKEND_URL); + } + + #[test] + fn api_key_mode_ignores_tokens() { + let mut f = NamedTempFile::new().unwrap(); + writeln!( + f, + r#"{{"auth_mode":"apiKey","OPENAI_API_KEY":"sk-priority","tokens":{{"id_token":{{}},"access_token":"eyJ-fallback","refresh_token":"rt-x"}}}}"# + ) + .unwrap(); + let creds = load_codex_credentials(f.path()).expect("should load"); + assert_eq!(creds.token.expose_secret(), "sk-priority"); + assert!(!creds.is_chatgpt_mode); + } + + #[test] + fn returns_none_for_missing_file() { + assert!(load_codex_credentials(Path::new("/tmp/nonexistent_codex_auth.json")).is_none()); + } + + #[test] + fn returns_none_for_empty_json() { + let mut f = NamedTempFile::new().unwrap(); + writeln!(f, "{{}}").unwrap(); + assert!(load_codex_credentials(f.path()).is_none()); + } + + #[test] + fn returns_none_for_empty_key() { + let mut f = NamedTempFile::new().unwrap(); + writeln!(f, 
r#"{{"auth_mode":"apiKey","OPENAI_API_KEY":""}}"#).unwrap(); + assert!(load_codex_credentials(f.path()).is_none()); + } + + #[test] + fn api_key_mode_missing_key_does_not_fallback_to_chatgpt() { + // Bug: if auth_mode is "apiKey" but key is missing, the old code would + // fall through to check for a ChatGPT token, returning is_chatgpt_mode: true. + let mut f = NamedTempFile::new().unwrap(); + writeln!( + f, + r#"{{"auth_mode":"apiKey","OPENAI_API_KEY":"","tokens":{{"id_token":{{}},"access_token":"eyJ-bad","refresh_token":"rt-x"}}}}"# + ) + .unwrap(); + assert!(load_codex_credentials(f.path()).is_none()); + } +} diff --git a/src/llm/codex_chatgpt.rs b/src/llm/codex_chatgpt.rs new file mode 100644 index 0000000000..56cb337862 --- /dev/null +++ b/src/llm/codex_chatgpt.rs @@ -0,0 +1,932 @@ +//! Codex ChatGPT Responses API provider. +//! +//! Implements `LlmProvider` by speaking the OpenAI Responses API protocol +//! (`POST /responses`) used by the ChatGPT backend at +//! `chatgpt.com/backend-api/codex`. This bypasses `rig-core`'s Chat +//! Completions path, which is incompatible with this endpoint. +//! +//! # Warning +//! +//! The ChatGPT backend endpoint (`chatgpt.com/backend-api/codex`) is a +//! **private, undocumented API**. Using subscriber OAuth tokens from a +//! third-party application may violate the token's intended scope or +//! OpenAI's Terms of Service. This feature is provided as-is for +//! convenience and may break without notice. 
+ +use async_trait::async_trait; +use eventsource_stream::Eventsource; +use futures::{Stream, StreamExt}; +use reqwest::Client; +use rust_decimal::Decimal; +use secrecy::{ExposeSecret, SecretString}; +use serde_json::{Value, json}; +use std::path::PathBuf; +use std::time::Duration; +use tokio::sync::{Mutex, RwLock}; + +use super::codex_auth; +use crate::error::LlmError; + +use super::provider::{ + ChatMessage, CompletionRequest, CompletionResponse, ContentPart, FinishReason, LlmProvider, + Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse, ToolDefinition, +}; + +/// Provider that speaks the Responses API protocol against the ChatGPT backend. +pub struct CodexChatGptProvider { + client: Client, + base_url: String, + api_key: RwLock, + /// User-configured model name (or empty/"default" for auto-detect). + configured_model: String, + /// Lazily resolved model name (populated on first LLM call). + resolved_model: tokio::sync::OnceCell, + /// OAuth refresh token for automatic 401 retry. + refresh_token: Option, + /// Path to auth.json for persisting refreshed tokens. + auth_path: Option, + /// Timeout for actual `/responses` requests. + request_timeout: Duration, + /// Prevent concurrent 401 handlers from racing the same refresh token. + refresh_lock: Mutex<()>, +} + +impl CodexChatGptProvider { + #[cfg(test)] + fn new(base_url: &str, api_key: &str, model: &str) -> Self { + Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key: RwLock::new(SecretString::from(api_key.to_string())), + configured_model: model.to_string(), + resolved_model: tokio::sync::OnceCell::const_new(), + refresh_token: None, + auth_path: None, + request_timeout: Duration::from_secs(120), + refresh_lock: Mutex::new(()), + } + } + + /// Create a provider with lazy model detection. + /// + /// The model is **not** resolved during construction. 
Instead, it is + /// resolved on the first LLM call via [`resolve_model`], avoiding the + /// need for `block_in_place` / `block_on` during provider setup. + /// + /// **Model selection priority** (applied at resolution time): + /// 1. If `configured_model` is non-empty, validate it against the + /// `/models` endpoint. If it isn't in the supported list, log a + /// warning with available models and fall back to the top model. + /// 2. If `configured_model` is empty (or a generic placeholder like + /// "default"), auto-detect the highest-priority model from the API. + pub fn with_lazy_model( + base_url: &str, + api_key: SecretString, + configured_model: &str, + refresh_token: Option, + auth_path: Option, + request_timeout_secs: u64, + ) -> Self { + tracing::warn!( + "Codex ChatGPT provider uses a private, undocumented API \ + (chatgpt.com/backend-api/codex). This may violate OpenAI's \ + Terms of Service and could break without notice." + ); + + Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key: RwLock::new(api_key), + configured_model: configured_model.to_string(), + resolved_model: tokio::sync::OnceCell::const_new(), + refresh_token, + auth_path, + request_timeout: Duration::from_secs(request_timeout_secs), + refresh_lock: Mutex::new(()), + } + } + + /// Resolve the model to use, lazily on first call. + /// + /// Uses `OnceCell` so the `/models` fetch happens at most once. 
+ async fn resolve_model(&self) -> &str { + self.resolved_model + .get_or_init(|| async { + let api_key = self.api_key.read().await.clone(); + let available = Self::fetch_available_models(&self.client, &self.base_url, &api_key) + .await; + + let configured = &self.configured_model; + if !configured.is_empty() && configured != "default" { + // User explicitly configured a model — validate it + if available.is_empty() { + tracing::warn!( + "Could not fetch model list; using configured model '{configured}'" + ); + return configured.clone(); + } + if available.iter().any(|m| m == configured) { + tracing::info!(model = %configured, "Codex ChatGPT: using configured model"); + return configured.clone(); + } + tracing::warn!( + configured = %configured, + available = ?available, + "Configured model not found in supported list, falling back to top model" + ); + available + .into_iter() + .next() + .unwrap_or_else(|| configured.clone()) + } else { + // No user preference — auto-detect + if let Some(top) = available.into_iter().next() { + tracing::info!(model = %top, "Codex ChatGPT: auto-detected model"); + top + } else { + tracing::warn!( + "Could not auto-detect model, using fallback '{configured}'" + ); + configured.clone() + } + } + }) + .await + } + + /// Query `/models?client_version=0.111.0` and return the list of available + /// model slugs, ordered by priority (highest first). 
+ async fn fetch_available_models( + client: &Client, + base_url: &str, + api_key: &SecretString, + ) -> Vec { + let url = format!("{base_url}/models?client_version=0.111.0"); + let resp = match client + .get(&url) + .bearer_auth(api_key.expose_secret()) + .timeout(Duration::from_secs(10)) + .send() + .await + { + Ok(r) => r, + Err(e) => { + tracing::warn!("Failed to fetch Codex models: {e}"); + return Vec::new(); + } + }; + if !resp.status().is_success() { + tracing::warn!(status = %resp.status(), "Failed to fetch Codex models"); + return Vec::new(); + } + let body: Value = match resp.json().await { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + // The response has { "models": [ { "slug": "...", ... }, ... ] } + body.get("models") + .and_then(|m| m.as_array()) + .map(|models| { + models + .iter() + .filter_map(|m| { + m.get("slug") + .and_then(|s| s.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default() + } + + /// Convert IronClaw messages to Responses API request JSON. 
+    fn build_request_body(
+        &self,
+        model: &str,
+        messages: &[ChatMessage],
+        tools: &[ToolDefinition],
+        tool_choice: Option<&str>,
+    ) -> Value {
+        // Extract system instructions
+        let instructions: String = messages
+            .iter()
+            .filter(|m| m.role == Role::System)
+            .map(|m| m.content.as_str())
+            .collect::<Vec<_>>()
+            .join("\n\n");
+
+        // Convert non-system messages to Responses API input items
+        let input: Vec<Value> = messages
+            .iter()
+            .filter(|m| m.role != Role::System)
+            .flat_map(Self::message_to_input_items)
+            .collect();
+
+        // Convert tool definitions
+        let api_tools: Vec<Value> = tools
+            .iter()
+            .map(|t| {
+                json!({
+                    "type": "function",
+                    "name": t.name,
+                    "description": t.description,
+                    "parameters": t.parameters,
+                })
+            })
+            .collect();
+
+        let mut body = json!({
+            "model": model,
+            "instructions": instructions,
+            "input": input,
+            "stream": true,
+            "store": false,
+        });
+
+        if !api_tools.is_empty() {
+            body["tools"] = json!(api_tools);
+            body["tool_choice"] = json!(tool_choice.unwrap_or("auto"));
+        }
+
+        body
+    }
+
+    /// Convert a single ChatMessage to one or more Responses API input items.
+    fn message_to_input_items(msg: &ChatMessage) -> Vec<Value> {
+        let mut items = Vec::new();
+
+        match msg.role {
+            Role::User => {
+                // Build content array: if content_parts is populated, use it
+                // to include multimodal content (images). Otherwise fall back
+                // to the plain text content field.
+                let content = if !msg.content_parts.is_empty() {
+                    msg.content_parts
+                        .iter()
+                        .map(|part| match part {
+                            ContentPart::Text { text } => json!({
+                                "type": "input_text",
+                                "text": text,
+                            }),
+                            ContentPart::ImageUrl { image_url } => json!({
+                                "type": "input_image",
+                                "image_url": image_url.url,
+                            }),
+                        })
+                        .collect::<Vec<_>>()
+                } else {
+                    vec![json!({
+                        "type": "input_text",
+                        "text": msg.content,
+                    })]
+                };
+
+                items.push(json!({
+                    "type": "message",
+                    "role": "user",
+                    "content": content,
+                }));
+            }
+            Role::Assistant => {
+                // If the assistant message has tool calls, emit function_call items
+                if let Some(ref tool_calls) = msg.tool_calls {
+                    // Emit the assistant text as a message if non-empty
+                    if !msg.content.is_empty() {
+                        items.push(json!({
+                            "type": "message",
+                            "role": "assistant",
+                            "content": [{
+                                "type": "output_text",
+                                "text": msg.content,
+                            }],
+                        }));
+                    }
+                    for tc in tool_calls {
+                        let args = if tc.arguments.is_string() {
+                            tc.arguments.as_str().unwrap_or("{}").to_string()
+                        } else {
+                            serde_json::to_string(&tc.arguments).unwrap_or_default()
+                        };
+                        items.push(json!({
+                            "type": "function_call",
+                            "name": tc.name,
+                            "arguments": args,
+                            "call_id": tc.id,
+                        }));
+                    }
+                } else {
+                    items.push(json!({
+                        "type": "message",
+                        "role": "assistant",
+                        "content": [{
+                            "type": "output_text",
+                            "text": msg.content,
+                        }],
+                    }));
+                }
+            }
+            Role::Tool => {
+                items.push(json!({
+                    "type": "function_call_output",
+                    "call_id": msg.tool_call_id.as_deref().unwrap_or(""),
+                    "output": msg.content,
+                }));
+            }
+            Role::System => {
+                // System messages are handled via `instructions` field
+            }
+        }
+
+        items
+    }
+
+    /// Send a request and parse the SSE response.
+    ///
+    /// On HTTP 401, if a refresh token is available, attempts to refresh
+    /// the access token and retry the request once.
+    async fn send_request(&self, body: Value) -> Result<ResponsesResult, LlmError> {
+        let url = format!("{}/responses", self.base_url);
+
+        tracing::debug!(
+            url = %url,
+            model = %body.get("model").and_then(|m| m.as_str()).unwrap_or("?"),
+            "Codex ChatGPT: sending request"
+        );
+
+        let api_key = self.api_key.read().await.clone();
+        let resp =
+            Self::send_http_request(&self.client, &url, &api_key, &body, self.request_timeout)
+                .await?;
+
+        let status = resp.status();
+        if status.as_u16() == 401 {
+            // Attempt token refresh if we have a refresh token
+            if let Some(ref rt) = self.refresh_token {
+                let _refresh_guard = self.refresh_lock.lock().await;
+                let current_token = self.api_key.read().await.clone();
+
+                if current_token.expose_secret() != api_key.expose_secret() {
+                    tracing::info!("Received 401, but another request already refreshed the token");
+                    let retry_resp = Self::send_http_request(
+                        &self.client,
+                        &url,
+                        &current_token,
+                        &body,
+                        self.request_timeout,
+                    )
+                    .await?;
+                    let retry_status = retry_resp.status();
+                    if !retry_status.is_success() {
+                        let body_text =
+                            tokio::time::timeout(Duration::from_secs(5), retry_resp.text())
+                                .await
+                                .unwrap_or(Ok(String::new()))
+                                .unwrap_or_default();
+                        return Err(LlmError::RequestFailed {
+                            provider: "codex_chatgpt".to_string(),
+                            reason: format!(
+                                "HTTP {retry_status} from {url} (after concurrent token refresh): {body_text}"
+                            ),
+                        });
+                    }
+                    return Self::parse_sse_response_stream(retry_resp, self.request_timeout).await;
+                }
+
+                tracing::info!("Received 401, attempting token refresh");
+                if let Some(new_token) =
+                    codex_auth::refresh_access_token(&self.client, rt, self.auth_path.as_deref())
+                        .await
+                {
+                    // Update stored api_key
+                    *self.api_key.write().await = new_token.clone();
+                    tracing::info!("Token refreshed, retrying request");
+
+                    // Retry the request with the new token
+                    let retry_resp = Self::send_http_request(
+                        &self.client,
+                        &url,
+                        &new_token,
+                        &body,
+                        self.request_timeout,
+                    )
+                    .await?;
+
+                    let retry_status = 
retry_resp.status(); + if !retry_status.is_success() { + let body_text = + tokio::time::timeout(Duration::from_secs(5), retry_resp.text()) + .await + .unwrap_or(Ok(String::new())) + .unwrap_or_default(); + return Err(LlmError::RequestFailed { + provider: "codex_chatgpt".to_string(), + reason: format!( + "HTTP {retry_status} from {url} (after token refresh): {body_text}" + ), + }); + } + + return Self::parse_sse_response_stream(retry_resp, self.request_timeout).await; + } else { + tracing::warn!( + "Token refresh failed. Please re-authenticate with: codex --login" + ); + } + } + + // No refresh token or refresh failed — return the 401 error + // Drain the response body to release the connection + let _ = resp.text().await; + return Err(LlmError::AuthFailed { + provider: "codex_chatgpt".to_string(), + }); + } + + if !status.is_success() { + // Read the error body with a timeout to avoid hanging + let body_text = tokio::time::timeout(Duration::from_secs(5), resp.text()) + .await + .unwrap_or(Ok(String::new())) + .unwrap_or_default(); + return Err(LlmError::RequestFailed { + provider: "codex_chatgpt".to_string(), + reason: format!("HTTP {status} from {url}: {body_text}",), + }); + } + + Self::parse_sse_response_stream(resp, self.request_timeout).await + } + + /// Low-level HTTP POST to the /responses endpoint. 
+    async fn send_http_request(
+        client: &Client,
+        url: &str,
+        api_key: &SecretString,
+        body: &Value,
+        timeout: Duration,
+    ) -> Result<reqwest::Response, LlmError> {
+        client
+            .post(url)
+            .bearer_auth(api_key.expose_secret())
+            .header("Content-Type", "application/json")
+            .header("Accept", "text/event-stream")
+            .json(body)
+            .timeout(timeout)
+            .send()
+            .await
+            .map_err(|e| LlmError::RequestFailed {
+                provider: "codex_chatgpt".to_string(),
+                reason: format!("HTTP request failed: {e}"),
+            })
+    }
+
+    async fn parse_sse_response_stream(
+        resp: reqwest::Response,
+        idle_timeout: Duration,
+    ) -> Result<ResponsesResult, LlmError> {
+        let stream = resp
+            .bytes_stream()
+            .map(|chunk| chunk.map_err(|e| e.to_string()));
+        Self::parse_sse_stream(stream, idle_timeout).await
+    }
+
+    async fn parse_sse_stream<S>(
+        stream: S,
+        idle_timeout: Duration,
+    ) -> Result<ResponsesResult, LlmError>
+    where
+        S: Stream<Item = Result<bytes::Bytes, String>> + Unpin,
+    {
+        let mut result = ResponsesResult::default();
+        let mut stream = stream.eventsource();
+
+        loop {
+            match tokio::time::timeout(idle_timeout, stream.next()).await {
+                Ok(Some(Ok(event))) => {
+                    let data = event.data.trim();
+                    if data.is_empty() {
+                        continue;
+                    }
+
+                    let parsed: Value = match serde_json::from_str(data) {
+                        Ok(v) => v,
+                        Err(_) => continue,
+                    };
+
+                    if Self::handle_sse_event(&mut result, event.event.as_str(), &parsed) {
+                        return Ok(result);
+                    }
+                }
+                Ok(Some(Err(e))) => {
+                    return Err(LlmError::RequestFailed {
+                        provider: "codex_chatgpt".to_string(),
+                        reason: format!("Failed to read SSE stream: {e}"),
+                    });
+                }
+                Ok(None) => return Ok(result),
+                Err(_) => {
+                    return Err(LlmError::RequestFailed {
+                        provider: "codex_chatgpt".to_string(),
+                        reason: format!(
+                            "Timed out waiting for SSE event after {}s",
+                            idle_timeout.as_secs()
+                        ),
+                    });
+                }
+            }
+        }
+    }
+
+    /// Parse SSE events from the response text.
+    #[cfg(test)]
+    fn parse_sse_response(sse_text: &str) -> Result<ResponsesResult, LlmError> {
+        let mut result = ResponsesResult::default();
+        let mut current_event_type = String::new();
+
+        for line in sse_text.lines() {
+            if let Some(event) = line.strip_prefix("event: ") {
+                current_event_type = event.trim().to_string();
+                continue;
+            }
+
+            if let Some(data) = line.strip_prefix("data: ") {
+                let data = data.trim();
+                if data.is_empty() {
+                    continue;
+                }
+
+                let parsed: Value = match serde_json::from_str(data) {
+                    Ok(v) => v,
+                    Err(_) => continue,
+                };
+
+                if Self::handle_sse_event(&mut result, current_event_type.as_str(), &parsed) {
+                    return Ok(result);
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    fn handle_sse_event(result: &mut ResponsesResult, event_type: &str, parsed: &Value) -> bool {
+        match event_type {
+            "response.output_text.delta" => {
+                if let Some(delta) = parsed.get("delta").and_then(|d| d.as_str()) {
+                    result.text.push_str(delta);
+                }
+            }
+            "response.output_item.added" => {
+                // Capture function call metadata when the item is first added.
+                // The item has: id (item_id), call_id, name, type.
+ let item = parsed.get("item").unwrap_or(parsed); + if item.get("type").and_then(|t| t.as_str()) == Some("function_call") { + let item_id = item + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let call_id = item + .get("call_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = item + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + result + .pending_tool_calls + .entry(item_id) + .or_insert_with(|| PendingToolCall { + call_id, + name, + arguments: String::new(), + }); + } + } + "response.function_call_arguments.delta" => { + // Delta events use `item_id` (not `call_id`) + if let Some(item_id) = parsed.get("item_id").and_then(|v| v.as_str()) + && let Some(entry) = result.pending_tool_calls.get_mut(item_id) + && let Some(delta) = parsed.get("delta").and_then(|d| d.as_str()) + { + entry.arguments.push_str(delta); + } + } + "response.completed" => { + if let Some(response) = parsed.get("response") + && let Some(usage) = response.get("usage") + { + result.input_tokens = usage + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + result.output_tokens = usage + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + } + return true; + } + _ => {} + } + + false + } + + /// Remove keys with empty-string values from a JSON object. + /// + /// gpt-5.2-codex fills optional tool parameters with `""` (e.g. + /// `"timestamp": ""`). IronClaw's tool validation treats these as + /// invalid "non-empty input expected". Stripping them makes the + /// tool see only the actually-provided values. 
+    fn strip_empty_string_values(value: Value) -> Value {
+        match value {
+            Value::Object(map) => {
+                let cleaned: serde_json::Map<String, Value> = map
+                    .into_iter()
+                    .filter(|(_, v)| !matches!(v, Value::String(s) if s.is_empty()))
+                    .map(|(k, v)| (k, Self::strip_empty_string_values(v)))
+                    .collect();
+                Value::Object(cleaned)
+            }
+            other => other,
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+struct ResponsesResult {
+    text: String,
+    /// Keyed by item_id (the SSE item identifier, e.g. "fc_...").
+    pending_tool_calls: std::collections::HashMap<String, PendingToolCall>,
+    input_tokens: u32,
+    output_tokens: u32,
+}
+
+#[derive(Debug)]
+struct PendingToolCall {
+    /// The call_id from the API (e.g. "call_..."), used to match results.
+    call_id: String,
+    name: String,
+    arguments: String,
+}
+
+#[async_trait]
+impl LlmProvider for CodexChatGptProvider {
+    fn model_name(&self) -> &str {
+        // Return resolved model if available, otherwise the configured name.
+        self.resolved_model
+            .get()
+            .map(|s| s.as_str())
+            .unwrap_or(&self.configured_model)
+    }
+
+    fn cost_per_token(&self) -> (Decimal, Decimal) {
+        // ChatGPT backend doesn't expose per-token pricing
+        (Decimal::ZERO, Decimal::ZERO)
+    }
+
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
+        let model = self.resolve_model().await;
+        let body = self.build_request_body(model, &request.messages, &[], None);
+        let result = self.send_request(body).await?;
+
+        Ok(CompletionResponse {
+            content: result.text,
+            input_tokens: result.input_tokens,
+            output_tokens: result.output_tokens,
+            finish_reason: FinishReason::Stop,
+            cache_read_input_tokens: 0,
+            cache_creation_input_tokens: 0,
+        })
+    }
+
+    async fn complete_with_tools(
+        &self,
+        request: ToolCompletionRequest,
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        let model = self.resolve_model().await;
+        let body = self.build_request_body(
+            model,
+            &request.messages,
+            &request.tools,
+            request.tool_choice.as_deref(),
+        );
+        let result = self.send_request(body).await?;
+
+        let tool_calls: Vec<ToolCall> = result
+            .pending_tool_calls
+            
.into_values() + .map(|tc| { + let args: Value = + serde_json::from_str(&tc.arguments).unwrap_or_else(|_| json!(tc.arguments)); + // gpt-5.2-codex fills optional parameters with empty strings (e.g. + // `"timestamp": ""`), which IronClaw's tool validation rejects. + // Strip them so only actually-provided values reach the tool. + let args = Self::strip_empty_string_values(args); + ToolCall { + id: tc.call_id, + name: tc.name, + arguments: args, + } + }) + .collect(); + + let finish_reason = if tool_calls.is_empty() { + FinishReason::Stop + } else { + FinishReason::ToolUse + }; + + Ok(ToolCompletionResponse { + content: if result.text.is_empty() { + None + } else { + Some(result.text) + }, + tool_calls, + input_tokens: result.input_tokens, + output_tokens: result.output_tokens, + finish_reason, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use futures::stream; + + #[test] + fn test_message_conversion_user() { + let items = CodexChatGptProvider::message_to_input_items(&ChatMessage::user("hello")); + assert_eq!(items.len(), 1); + assert_eq!(items[0]["type"], "message"); + assert_eq!(items[0]["role"], "user"); + assert_eq!(items[0]["content"][0]["type"], "input_text"); + assert_eq!(items[0]["content"][0]["text"], "hello"); + } + + #[test] + fn test_message_conversion_user_with_image() { + use super::super::provider::ImageUrl; + let parts = vec![ + ContentPart::Text { + text: "What's in this image?".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "data:image/png;base64,iVBOR...".to_string(), + detail: None, + }, + }, + ]; + let msg = ChatMessage::user_with_parts("", parts); + let items = CodexChatGptProvider::message_to_input_items(&msg); + assert_eq!(items.len(), 1); + assert_eq!(items[0]["type"], "message"); + assert_eq!(items[0]["role"], "user"); + let content = items[0]["content"].as_array().unwrap(); + assert_eq!(content.len(), 2); + 
assert_eq!(content[0]["type"], "input_text"); + assert_eq!(content[0]["text"], "What's in this image?"); + assert_eq!(content[1]["type"], "input_image"); + assert_eq!(content[1]["image_url"], "data:image/png;base64,iVBOR..."); + } + #[test] + fn test_message_conversion_assistant() { + let items = CodexChatGptProvider::message_to_input_items(&ChatMessage::assistant("hi")); + assert_eq!(items.len(), 1); + assert_eq!(items[0]["type"], "message"); + assert_eq!(items[0]["role"], "assistant"); + assert_eq!(items[0]["content"][0]["type"], "output_text"); + } + + #[test] + fn test_message_conversion_tool_result() { + let msg = ChatMessage::tool_result("call_1", "search", "result text"); + let items = CodexChatGptProvider::message_to_input_items(&msg); + assert_eq!(items.len(), 1); + assert_eq!(items[0]["type"], "function_call_output"); + assert_eq!(items[0]["call_id"], "call_1"); + assert_eq!(items[0]["output"], "result text"); + } + + #[test] + fn test_message_conversion_assistant_with_tool_calls() { + let tc = ToolCall { + id: "call_1".to_string(), + name: "search".to_string(), + arguments: json!({"query": "rust"}), + }; + let msg = ChatMessage::assistant_with_tool_calls(Some("thinking...".into()), vec![tc]); + let items = CodexChatGptProvider::message_to_input_items(&msg); + // Should produce: 1 text message + 1 function_call + assert_eq!(items.len(), 2); + assert_eq!(items[0]["type"], "message"); + assert_eq!(items[1]["type"], "function_call"); + assert_eq!(items[1]["name"], "search"); + assert_eq!(items[1]["call_id"], "call_1"); + } + + #[test] + fn test_build_request_extracts_system_as_instructions() { + let provider = CodexChatGptProvider::new("https://example.com", "key", "gpt-4o"); + let messages = vec![ + ChatMessage::system("You are helpful."), + ChatMessage::user("hello"), + ]; + let body = provider.build_request_body("gpt-4o", &messages, &[], None); + assert_eq!(body["instructions"], "You are helpful."); + // input should only contain the user message, not the 
system message + assert_eq!(body["input"].as_array().unwrap().len(), 1); + // store must be false for ChatGPT backend + assert_eq!(body["store"], false); + } + + #[test] + fn test_parse_sse_text_response() { + let sse = r#"event: response.output_text.delta +data: {"delta":"Hello"} + +event: response.output_text.delta +data: {"delta":" world!"} + +event: response.completed +data: {"response":{"usage":{"input_tokens":10,"output_tokens":5}}} + +"#; + let result = CodexChatGptProvider::parse_sse_response(sse).unwrap(); + assert_eq!(result.text, "Hello world!"); + assert_eq!(result.input_tokens, 10); + assert_eq!(result.output_tokens, 5); + assert!(result.pending_tool_calls.is_empty()); + } + + #[test] + fn test_parse_sse_tool_call() { + // Real API format: output_item.added has item.id (item_id) + item.call_id, + // delta events use item_id (not call_id) + let sse = r#"event: response.output_item.added +data: {"item":{"id":"fc_1","type":"function_call","call_id":"call_1","name":"search"}} + +event: response.function_call_arguments.delta +data: {"item_id":"fc_1","delta":"{\"query\":"} + +event: response.function_call_arguments.delta +data: {"item_id":"fc_1","delta":"\"rust\"}"} + +event: response.completed +data: {"response":{"usage":{"input_tokens":20,"output_tokens":15}}} + +"#; + let result = CodexChatGptProvider::parse_sse_response(sse).unwrap(); + assert!(result.text.is_empty()); + assert_eq!(result.pending_tool_calls.len(), 1); + let tc = result.pending_tool_calls.get("fc_1").unwrap(); + assert_eq!(tc.call_id, "call_1"); + assert_eq!(tc.name, "search"); + assert_eq!(tc.arguments, "{\"query\":\"rust\"}"); + } + + #[tokio::test] + async fn test_parse_sse_stream_response() { + let stream = stream::iter(vec![ + Ok(Bytes::from_static( + b"event: response.output_text.delta\ndata: {\"delta\":\"Hello\"}\n\n", + )), + Ok(Bytes::from_static( + b"event: response.output_text.delta\ndata: {\"delta\":\" world\"}\n\n", + )), + Ok(Bytes::from_static( + b"event: 
response.completed\ndata: {\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}}\n\n", + )), + ]); + + let result = CodexChatGptProvider::parse_sse_stream(stream, Duration::from_secs(1)) + .await + .unwrap(); + assert_eq!(result.text, "Hello world"); + assert_eq!(result.input_tokens, 3); + assert_eq!(result.output_tokens, 2); + } + + #[test] + fn test_strip_empty_string_values() { + let input = json!({ + "format": "%Y-%m-%d", + "operation": "now", + "timestamp": "", + "timestamp2": "", + }); + let cleaned = CodexChatGptProvider::strip_empty_string_values(input); + assert_eq!(cleaned, json!({"format": "%Y-%m-%d", "operation": "now"})); + } +} diff --git a/src/llm/config.rs b/src/llm/config.rs index 1902f128b3..8b7d41c3c8 100644 --- a/src/llm/config.rs +++ b/src/llm/config.rs @@ -5,6 +5,8 @@ //! extracted into a standalone crate. Resolution logic (reading env vars, //! settings) lives in `crate::config::llm`. +use std::path::PathBuf; + use secrecy::SecretString; use crate::llm::registry::ProviderProtocol; @@ -85,6 +87,13 @@ pub struct RegistryProviderConfig { /// OAuth token for providers that support Bearer auth (e.g. Anthropic via `claude login`). /// When set, the provider factory routes to the OAuth-specific provider implementation. pub oauth_token: Option, + /// When true, route OpenAI-compatible traffic to the Codex ChatGPT + /// Responses API provider instead of rig-core's Chat Completions path. + pub is_codex_chatgpt: bool, + /// OAuth refresh token for Codex ChatGPT token refresh. + pub refresh_token: Option, + /// Path to Codex auth.json for persisting refreshed tokens. + pub auth_path: Option, /// Prompt cache retention (Anthropic-specific). pub cache_retention: CacheRetention, /// Parameter names that this provider does not support (e.g., `["temperature"]`). @@ -163,3 +172,42 @@ pub struct NearAiConfig { /// Enable cascade mode for smart routing. Default: true. 
pub smart_routing_cascade: bool, } + +impl NearAiConfig { + /// Create a minimal config suitable for listing available models. + /// + /// Reads `NEARAI_API_KEY` from the environment and selects the + /// appropriate base URL (cloud-api when API key is present, + /// private.near.ai for session-token auth). + pub(crate) fn for_model_discovery() -> Self { + let api_key = std::env::var("NEARAI_API_KEY") + .ok() + .filter(|k| !k.is_empty()) + .map(SecretString::from); + + let default_base = if api_key.is_some() { + "https://cloud-api.near.ai" + } else { + "https://private.near.ai" + }; + let base_url = + std::env::var("NEARAI_BASE_URL").unwrap_or_else(|_| default_base.to_string()); + + Self { + model: String::new(), + cheap_model: None, + base_url, + api_key, + fallback_model: None, + max_retries: 3, + circuit_breaker_threshold: None, + circuit_breaker_recovery_secs: 30, + response_cache_enabled: false, + response_cache_ttl_secs: 3600, + response_cache_max_entries: 1000, + failover_cooldown_secs: 300, + failover_cooldown_threshold: 3, + smart_routing_cascade: true, + } + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index b49e4974a1..51309bf37d 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -12,6 +12,8 @@ mod anthropic_oauth; #[cfg(feature = "bedrock")] mod bedrock; pub mod circuit_breaker; +pub(crate) mod codex_auth; +mod codex_chatgpt; pub mod config; pub mod costs; pub mod error; @@ -29,6 +31,7 @@ pub mod session; pub mod smart_routing; pub mod image_models; +pub mod models; pub mod reasoning_models; pub mod vision_models; @@ -101,7 +104,7 @@ pub async fn create_llm_provider( provider: config.backend.clone(), })?; - create_registry_provider(reg_config) + create_registry_provider(reg_config, timeout) } /// Create an LLM provider from a `NearAiConfig` directly. @@ -139,7 +142,13 @@ pub fn create_llm_provider_with_config( /// `create_*_provider` functions. 
fn create_registry_provider( config: &RegistryProviderConfig, + request_timeout_secs: u64, ) -> Result, LlmError> { + // Codex ChatGPT mode: use the Responses API provider + if config.is_codex_chatgpt { + return create_codex_chatgpt_from_registry(config, request_timeout_secs); + } + match config.protocol { ProviderProtocol::OpenAiCompletions => create_openai_compat_from_registry(config), ProviderProtocol::Anthropic => create_anthropic_from_registry(config), @@ -147,6 +156,36 @@ fn create_registry_provider( } } +fn create_codex_chatgpt_from_registry( + config: &RegistryProviderConfig, + request_timeout_secs: u64, +) -> Result, LlmError> { + let api_key = config + .api_key + .as_ref() + .cloned() + .ok_or_else(|| LlmError::AuthFailed { + provider: "codex_chatgpt".to_string(), + })?; + + tracing::info!( + configured_model = %config.model, + base_url = %config.base_url, + "Using Codex ChatGPT provider (Responses API) — model detection deferred to first call" + ); + + let provider = codex_chatgpt::CodexChatGptProvider::with_lazy_model( + &config.base_url, + api_key, + &config.model, + config.refresh_token.clone(), + config.auth_path.clone(), + request_timeout_secs, + ); + + Ok(Arc::new(provider)) +} + #[cfg(feature = "bedrock")] async fn create_bedrock_provider(config: &LlmConfig) -> Result, LlmError> { let br = config @@ -162,6 +201,7 @@ async fn create_bedrock_provider(config: &LlmConfig) -> Result) -> Vec<(String, String)> { + let static_defaults = vec![ + ( + "claude-opus-4-6".into(), + "Claude Opus 4.6 (latest flagship)".into(), + ), + ("claude-sonnet-4-6".into(), "Claude Sonnet 4.6".into()), + ("claude-opus-4-5".into(), "Claude Opus 4.5".into()), + ("claude-sonnet-4-5".into(), "Claude Sonnet 4.5".into()), + ("claude-haiku-4-5".into(), "Claude Haiku 4.5 (fast)".into()), + ]; + + let api_key = cached_key + .map(String::from) + .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok()) + .filter(|k| !k.is_empty() && k != crate::config::OAUTH_PLACEHOLDER); + + // Fall back 
to OAuth token if no API key + let oauth_token = if api_key.is_none() { + crate::config::helpers::optional_env("ANTHROPIC_OAUTH_TOKEN") + .ok() + .flatten() + .filter(|t| !t.is_empty()) + } else { + None + }; + + let (key_or_token, is_oauth) = match (api_key, oauth_token) { + (Some(k), _) => (k, false), + (None, Some(t)) => (t, true), + (None, None) => return static_defaults, + }; + + let client = reqwest::Client::new(); + let mut request = client + .get("https://api.anthropic.com/v1/models") + .header("anthropic-version", "2023-06-01") + .timeout(std::time::Duration::from_secs(5)); + + if is_oauth { + request = request + .bearer_auth(&key_or_token) + .header("anthropic-beta", "oauth-2025-04-20"); + } else { + request = request.header("x-api-key", &key_or_token); + } + + let resp = match request.send().await { + Ok(r) if r.status().is_success() => r, + _ => return static_defaults, + }; + + #[derive(serde::Deserialize)] + struct ModelEntry { + id: String, + } + #[derive(serde::Deserialize)] + struct ModelsResponse { + data: Vec, + } + + match resp.json::().await { + Ok(body) => { + let mut models: Vec<(String, String)> = body + .data + .into_iter() + .filter(|m| !m.id.contains("embedding") && !m.id.contains("audio")) + .map(|m| { + let label = m.id.clone(); + (m.id, label) + }) + .collect(); + if models.is_empty() { + return static_defaults; + } + models.sort_by(|a, b| a.0.cmp(&b.0)); + models + } + Err(_) => static_defaults, + } +} + +/// Fetch models from the OpenAI API. +/// +/// Returns `(model_id, display_label)` pairs. Falls back to static defaults on error. 
+pub(crate) async fn fetch_openai_models(cached_key: Option<&str>) -> Vec<(String, String)> { + let static_defaults = vec![ + ( + "gpt-5.3-codex".into(), + "GPT-5.3 Codex (latest flagship)".into(), + ), + ("gpt-5.2-codex".into(), "GPT-5.2 Codex".into()), + ("gpt-5.2".into(), "GPT-5.2".into()), + ( + "gpt-5.1-codex-mini".into(), + "GPT-5.1 Codex Mini (fast)".into(), + ), + ("gpt-5".into(), "GPT-5".into()), + ("gpt-5-mini".into(), "GPT-5 Mini".into()), + ("gpt-4.1".into(), "GPT-4.1".into()), + ("gpt-4.1-mini".into(), "GPT-4.1 Mini".into()), + ("o4-mini".into(), "o4-mini (fast reasoning)".into()), + ("o3".into(), "o3 (reasoning)".into()), + ]; + + let api_key = cached_key + .map(String::from) + .or_else(|| std::env::var("OPENAI_API_KEY").ok()) + .filter(|k| !k.is_empty()); + + let api_key = match api_key { + Some(k) => k, + None => return static_defaults, + }; + + let client = reqwest::Client::new(); + let resp = match client + .get("https://api.openai.com/v1/models") + .bearer_auth(&api_key) + .timeout(std::time::Duration::from_secs(5)) + .send() + .await + { + Ok(r) if r.status().is_success() => r, + _ => return static_defaults, + }; + + #[derive(serde::Deserialize)] + struct ModelEntry { + id: String, + } + #[derive(serde::Deserialize)] + struct ModelsResponse { + data: Vec, + } + + match resp.json::().await { + Ok(body) => { + let mut models: Vec<(String, String)> = body + .data + .into_iter() + .filter(|m| is_openai_chat_model(&m.id)) + .map(|m| { + let label = m.id.clone(); + (m.id, label) + }) + .collect(); + if models.is_empty() { + return static_defaults; + } + sort_openai_models(&mut models); + models + } + Err(_) => static_defaults, + } +} + +pub(crate) fn is_openai_chat_model(model_id: &str) -> bool { + let id = model_id.to_ascii_lowercase(); + + let is_chat_family = id.starts_with("gpt-") + || id.starts_with("chatgpt-") + || id.starts_with("o1") + || id.starts_with("o3") + || id.starts_with("o4") + || id.starts_with("o5"); + + let is_non_chat_variant = 
id.contains("realtime") + || id.contains("audio") + || id.contains("transcribe") + || id.contains("tts") + || id.contains("embedding") + || id.contains("moderation") + || id.contains("image"); + + is_chat_family && !is_non_chat_variant +} + +pub(crate) fn openai_model_priority(model_id: &str) -> usize { + let id = model_id.to_ascii_lowercase(); + + const EXACT_PRIORITY: &[&str] = &[ + "gpt-5.3-codex", + "gpt-5.2-codex", + "gpt-5.2", + "gpt-5.1-codex-mini", + "gpt-5", + "gpt-5-mini", + "gpt-5-nano", + "o4-mini", + "o3", + "o1", + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4o", + "gpt-4o-mini", + ]; + if let Some(pos) = EXACT_PRIORITY.iter().position(|m| id == *m) { + return pos; + } + + const PREFIX_PRIORITY: &[&str] = &[ + "gpt-5.", "gpt-5-", "o3-", "o4-", "o1-", "gpt-4.1-", "gpt-4o-", "gpt-3.5-", "chatgpt-", + ]; + if let Some(pos) = PREFIX_PRIORITY + .iter() + .position(|prefix| id.starts_with(prefix)) + { + return EXACT_PRIORITY.len() + pos; + } + + EXACT_PRIORITY.len() + PREFIX_PRIORITY.len() + 1 +} + +pub(crate) fn sort_openai_models(models: &mut [(String, String)]) { + models.sort_by(|a, b| { + openai_model_priority(&a.0) + .cmp(&openai_model_priority(&b.0)) + .then_with(|| a.0.cmp(&b.0)) + }); +} + +/// Fetch installed models from a local Ollama instance. +/// +/// Returns `(model_name, display_label)` pairs. Falls back to static defaults on error. +pub(crate) async fn fetch_ollama_models(base_url: &str) -> Vec<(String, String)> { + let static_defaults = vec![ + ("llama3".into(), "llama3".into()), + ("mistral".into(), "mistral".into()), + ("codellama".into(), "codellama".into()), + ]; + + let url = format!("{}/api/tags", base_url.trim_end_matches('/')); + let client = reqwest::Client::new(); + + let resp = match client + .get(&url) + .timeout(std::time::Duration::from_secs(5)) + .send() + .await + { + Ok(r) if r.status().is_success() => r, + Ok(_) => return static_defaults, + Err(_) => { + tracing::warn!( + "Could not connect to Ollama at {base_url}. Is it running? 
Using static defaults." + ); + return static_defaults; + } + }; + + #[derive(serde::Deserialize)] + struct ModelEntry { + name: String, + } + #[derive(serde::Deserialize)] + struct TagsResponse { + models: Vec, + } + + match resp.json::().await { + Ok(body) => { + let models: Vec<(String, String)> = body + .models + .into_iter() + .map(|m| { + let label = m.name.clone(); + (m.name, label) + }) + .collect(); + if models.is_empty() { + return static_defaults; + } + models + } + Err(_) => static_defaults, + } +} + +/// Fetch models from a generic OpenAI-compatible /v1/models endpoint. +/// +/// Used for registry providers like Groq, NVIDIA NIM, etc. +pub(crate) async fn fetch_openai_compatible_models( + base_url: &str, + cached_key: Option<&str>, +) -> Vec<(String, String)> { + if base_url.is_empty() { + return vec![]; + } + + let url = format!("{}/models", base_url.trim_end_matches('/')); + let client = reqwest::Client::new(); + let mut req = client.get(&url).timeout(std::time::Duration::from_secs(5)); + if let Some(key) = cached_key { + req = req.bearer_auth(key); + } + + let resp = match req.send().await { + Ok(r) if r.status().is_success() => r, + _ => return vec![], + }; + + #[derive(serde::Deserialize)] + struct Model { + id: String, + } + #[derive(serde::Deserialize)] + struct ModelsResponse { + data: Vec, + } + + match resp.json::().await { + Ok(body) => body + .data + .into_iter() + .map(|m| { + let label = m.id.clone(); + (m.id, label) + }) + .collect(), + Err(_) => vec![], + } +} + +/// Build the `LlmConfig` used by `fetch_nearai_models` to list available models. +/// +/// Uses [`NearAiConfig::for_model_discovery()`] to construct a minimal NEAR AI +/// config, then wraps it in an `LlmConfig` with session config for auth. 
+pub(crate) fn build_nearai_model_fetch_config() -> crate::config::LlmConfig { + let auth_base_url = + std::env::var("NEARAI_AUTH_URL").unwrap_or_else(|_| "https://private.near.ai".to_string()); + + crate::config::LlmConfig { + backend: "nearai".to_string(), + session: crate::llm::session::SessionConfig { + auth_base_url, + session_path: crate::config::llm::default_session_path(), + }, + nearai: crate::config::NearAiConfig::for_model_discovery(), + provider: None, + bedrock: None, + request_timeout_secs: 120, + } +} diff --git a/src/llm/rig_adapter.rs b/src/llm/rig_adapter.rs index 41724c319e..5c1faef79f 100644 --- a/src/llm/rig_adapter.rs +++ b/src/llm/rig_adapter.rs @@ -357,15 +357,31 @@ fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec { - // Tool result message: wrap as User { ToolResult } + // Tool result message: wrap as User { ToolResult }. + // Merge consecutive tool results into a single User message + // so the API sees one multi-result message instead of + // multiple consecutive User messages (which Anthropic rejects). 
let tool_id = normalized_tool_call_id(msg.tool_call_id.as_deref(), history.len()); - history.push(RigMessage::User { - content: OneOrMany::one(UserContent::ToolResult(RigToolResult { - id: tool_id.clone(), - call_id: Some(tool_id), - content: OneOrMany::one(ToolResultContent::text(&msg.content)), - })), + let tool_result = UserContent::ToolResult(RigToolResult { + id: tool_id.clone(), + call_id: Some(tool_id), + content: OneOrMany::one(ToolResultContent::text(&msg.content)), }); + + let should_merge = matches!( + history.last(), + Some(RigMessage::User { content }) if content.iter().all(|c| matches!(c, UserContent::ToolResult(_))) + ); + + if should_merge { + if let Some(RigMessage::User { content }) = history.last_mut() { + content.push(tool_result); + } + } else { + history.push(RigMessage::User { + content: OneOrMany::one(tool_result), + }); + } } } } @@ -1280,4 +1296,68 @@ mod tests { assert!(adapter.unsupported_params.is_empty()); } + + /// Regression test: consecutive tool_result messages from parallel tool + /// execution must be merged into a single User message with multiple + /// ToolResult content items. Without merging, APIs like Anthropic reject + /// the request due to consecutive User messages. 
+ #[test] + fn test_consecutive_tool_results_merged_into_single_user_message() { + let tc1 = IronToolCall { + id: "call_a".to_string(), + name: "search".to_string(), + arguments: serde_json::json!({"q": "rust"}), + }; + let tc2 = IronToolCall { + id: "call_b".to_string(), + name: "fetch".to_string(), + arguments: serde_json::json!({"url": "https://example.com"}), + }; + let assistant = ChatMessage::assistant_with_tool_calls(None, vec![tc1, tc2]); + let result_a = ChatMessage::tool_result("call_a", "search", "search results"); + let result_b = ChatMessage::tool_result("call_b", "fetch", "fetch results"); + + let messages = vec![assistant, result_a, result_b]; + let (_preamble, history) = convert_messages(&messages); + + // Should be: 1 assistant + 1 merged user (not 1 assistant + 2 users) + assert_eq!( + history.len(), + 2, + "Expected 2 messages (assistant + merged user), got {}", + history.len() + ); + + // The second message should contain both tool results + match &history[1] { + RigMessage::User { content } => { + assert_eq!( + content.len(), + 2, + "Expected 2 tool results in merged user message, got {}", + content.len() + ); + for item in content.iter() { + assert!( + matches!(item, UserContent::ToolResult(_)), + "Expected ToolResult content" + ); + } + } + other => panic!("Expected User message, got: {:?}", other), + } + } + + /// Verify that a tool_result after a non-tool User message is NOT merged. 
+ #[test] + fn test_tool_result_after_user_text_not_merged() { + let user_msg = ChatMessage::user("hello"); + let tool_msg = ChatMessage::tool_result("call_1", "search", "results"); + + let messages = vec![user_msg, tool_msg]; + let (_preamble, history) = convert_messages(&messages); + + // Should be 2 separate User messages (text user + tool result user) + assert_eq!(history.len(), 2); + } } diff --git a/src/main.rs b/src/main.rs index 0b4695305a..574616772d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,6 +27,8 @@ use ironclaw::{ webhooks::{self, ToolWebhookState}, }; +#[cfg(unix)] +use ironclaw::channels::ChannelSecretUpdater; #[cfg(any(feature = "postgres", feature = "libsql"))] use ironclaw::setup::{SetupConfig, SetupWizard}; @@ -521,6 +523,30 @@ async fn async_main() -> anyhow::Result<()> { } } + // Persist auto-generated auth token so it survives restarts. + // Write to the "default" settings namespace, which is the namespace + // Config::from_db() reads from — NOT the gateway channel's user_id. 
+ if gw_config.auth_token.is_none() { + let token_to_persist = gw.auth_token().to_string(); + if let Some(ref db) = components.db { + let db = db.clone(); + tokio::spawn(async move { + if let Err(e) = db + .set_setting( + "default", + "channels.gateway_auth_token", + &serde_json::Value::String(token_to_persist), + ) + .await + { + tracing::warn!("Failed to persist auto-generated gateway auth token: {e}"); + } else { + tracing::debug!("Persisted auto-generated gateway auth token to settings"); + } + }); + } + } + gateway_url = Some(format!( "http://{}:{}/?token={}", gw_config.host, @@ -740,7 +766,6 @@ async fn async_main() -> anyhow::Result<()> { #[cfg(unix)] { - use ironclaw::channels::ChannelSecretUpdater; // Collect all channels that support secret updates let mut secret_updaters: Vec> = Vec::new(); if let Some(ref state) = http_channel_state { diff --git a/src/sandbox/manager.rs b/src/sandbox/manager.rs index ce709f5081..1c0decc842 100644 --- a/src/sandbox/manager.rs +++ b/src/sandbox/manager.rs @@ -236,14 +236,59 @@ impl SandboxManager { self.initialize().await?; } - // Get proxy port if running + // Retry transient container failures (Docker daemon glitches, container + // creation races) up to MAX_SANDBOX_RETRIES times with exponential backoff. 
+ const MAX_SANDBOX_RETRIES: u32 = 2; + let mut last_err: Option = None; + + for attempt in 0..=MAX_SANDBOX_RETRIES { + if attempt > 0 { + let delay = std::time::Duration::from_secs(1 << attempt); // 2s, 4s + tracing::warn!( + attempt = attempt + 1, + max_attempts = MAX_SANDBOX_RETRIES + 1, + delay_secs = delay.as_secs(), + "Retrying sandbox execution after transient failure" + ); + tokio::time::sleep(delay).await; + } + + match self + .try_execute_in_container(command, cwd, policy, env.clone()) + .await + { + Ok(output) => return Ok(output), + Err(e) if is_transient_sandbox_error(&e) => { + tracing::warn!( + attempt = attempt + 1, + error = %e, + "Transient sandbox error, will retry" + ); + last_err = Some(e); + } + Err(e) => return Err(e), + } + } + + Err(last_err.unwrap_or_else(|| SandboxError::ExecutionFailed { + reason: "all retry attempts exhausted".to_string(), + })) + } + + /// Single attempt at container execution (no retry logic). + async fn try_execute_in_container( + &self, + command: &str, + cwd: &Path, + policy: SandboxPolicy, + env: HashMap, + ) -> Result { let proxy_port = if let Some(proxy) = self.proxy.read().await.as_ref() { proxy.addr().await.map(|a| a.port()).unwrap_or(0) } else { 0 }; - // Reuse the stored Docker connection, create a runner with the current proxy port let docker = self.docker .read() @@ -262,7 +307,6 @@ impl SandboxManager { }; let container_output = runner.execute(command, cwd, policy, &limits, env).await?; - Ok(container_output.into()) } @@ -373,6 +417,20 @@ impl Drop for SandboxManager { } } +/// Check whether a sandbox error is transient and worth retrying. +/// +/// Transient errors are those caused by Docker daemon glitches, container +/// creation race conditions, or container start failures — not by command +/// execution failures, timeouts, or policy violations. +fn is_transient_sandbox_error(err: &SandboxError) -> bool { + matches!( + err, + SandboxError::DockerNotAvailable { .. 
} + | SandboxError::ContainerCreationFailed { .. } + | SandboxError::ContainerStartFailed { .. } + ) +} + /// Builder for creating a sandbox manager. pub struct SandboxManagerBuilder { config: SandboxConfig, @@ -597,4 +655,43 @@ mod tests { assert!(output.truncated); assert!(output.stdout.len() <= 32 * 1024); } + + #[test] + fn transient_errors_are_retryable() { + assert!(super::is_transient_sandbox_error( + &SandboxError::DockerNotAvailable { + reason: "daemon restarting".to_string() + } + )); + assert!(super::is_transient_sandbox_error( + &SandboxError::ContainerCreationFailed { + reason: "image pull glitch".to_string() + } + )); + assert!(super::is_transient_sandbox_error( + &SandboxError::ContainerStartFailed { + reason: "cgroup race".to_string() + } + )); + } + + #[test] + fn non_transient_errors_are_not_retryable() { + assert!(!super::is_transient_sandbox_error(&SandboxError::Timeout( + std::time::Duration::from_secs(30) + ))); + assert!(!super::is_transient_sandbox_error( + &SandboxError::ExecutionFailed { + reason: "exit code 1".to_string() + } + )); + assert!(!super::is_transient_sandbox_error( + &SandboxError::NetworkBlocked { + reason: "policy violation".to_string() + } + )); + assert!(!super::is_transient_sandbox_error(&SandboxError::Config { + reason: "bad config".to_string() + })); + } } diff --git a/src/secrets/mod.rs b/src/secrets/mod.rs index 9ebad71598..9154b78b49 100644 --- a/src/secrets/mod.rs +++ b/src/secrets/mod.rs @@ -109,3 +109,59 @@ pub fn create_secrets_store( store } + +/// Try to resolve an existing master key from env var or OS keychain. +/// +/// Resolution order: +/// 1. `SECRETS_MASTER_KEY` environment variable (hex-encoded) +/// 2. OS keychain (macOS Keychain / Linux secret-service) +/// +/// Returns `None` if no key is available (caller should generate one). +pub async fn resolve_master_key() -> Option { + // 1. 
Check env var + if let Ok(env_key) = std::env::var("SECRETS_MASTER_KEY") + && !env_key.is_empty() + { + return Some(env_key); + } + + // 2. Try OS keychain + if let Ok(keychain_key_bytes) = keychain::get_master_key().await { + let key_hex: String = keychain_key_bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect(); + return Some(key_hex); + } + + None +} + +/// Create a `SecretsCrypto` from a master key string. +/// +/// The key is typically hex-encoded (from `generate_master_key_hex` or +/// the `SECRETS_MASTER_KEY` env var), but `SecretsCrypto::new` validates +/// only key length, not encoding. Any sufficiently long string works. +pub fn crypto_from_hex(hex: &str) -> Result, SecretError> { + let crypto = SecretsCrypto::new(secrecy::SecretString::from(hex.to_string()))?; + Ok(std::sync::Arc::new(crypto)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_crypto_from_hex_valid() { + // 32 bytes = 64 hex chars + let hex = "0123456789abcdef".repeat(4); // 64 hex chars + let result = crypto_from_hex(&hex); + assert!(result.is_ok()); // safety: test assertion + } + + #[test] + fn test_crypto_from_hex_invalid() { + let result = crypto_from_hex("too_short"); + assert!(result.is_err()); // safety: test assertion + } +} diff --git a/src/settings.rs b/src/settings.rs index 29bfbae169..2a5b6bbd21 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -360,6 +360,10 @@ pub struct HeartbeatSettings { #[serde(default)] pub notify_user: Option, + /// Fixed time-of-day to fire (HH:MM, 24h). When set, interval_secs is ignored. + #[serde(default)] + pub fire_at: Option, + /// Hour (0-23) when quiet hours start (heartbeat skipped). #[serde(default)] pub quiet_hours_start: Option, @@ -368,7 +372,7 @@ pub struct HeartbeatSettings { #[serde(default)] pub quiet_hours_end: Option, - /// Timezone for quiet hours evaluation (IANA name, e.g. "America/New_York"). + /// Timezone for fire_at and quiet hours (IANA name, e.g. "Pacific/Auckland"). 
#[serde(default)] pub timezone: Option, } @@ -384,6 +388,7 @@ impl Default for HeartbeatSettings { interval_secs: default_heartbeat_interval(), notify_channel: None, notify_user: None, + fire_at: None, quiet_hours_start: None, quiet_hours_end: None, timezone: None, @@ -1747,4 +1752,503 @@ mod tests { "None selected_model should stay None" ); } + + // === Wizard re-run regression tests === + // + // These tests simulate the merge ordering used by the wizard's `run()` method + // to verify that re-running the wizard (or a subset of steps) doesn't + // accidentally reset settings from prior runs. + + /// Simulates `ironclaw onboard --provider-only` re-running on a fully + /// configured installation. Only provider + model should change; all + /// other settings (channels, embeddings, heartbeat) must survive. + #[test] + fn provider_only_rerun_preserves_unrelated_settings() { + // Prior completed run with everything configured + let prior = Settings { + onboard_completed: true, + database_backend: Some("libsql".to_string()), + libsql_path: Some("/home/user/.ironclaw/ironclaw.db".to_string()), + llm_backend: Some("openai".to_string()), + selected_model: Some("gpt-4o".to_string()), + embeddings: EmbeddingsSettings { + enabled: true, + provider: "openai".to_string(), + model: "text-embedding-3-small".to_string(), + }, + channels: ChannelSettings { + http_enabled: true, + http_port: Some(8080), + signal_enabled: true, + signal_account: Some("+1234567890".to_string()), + wasm_channels: vec!["telegram".to_string()], + ..Default::default() + }, + heartbeat: HeartbeatSettings { + enabled: true, + interval_secs: 900, + ..Default::default() + }, + ..Default::default() + }; + let db_map = prior.to_db_map(); + + // provider_only mode: reconnect_existing_db loads from DB, + // then user picks a new provider + model via step_inference_provider + let mut current = Settings::from_db_map(&db_map); + + // Simulate step_inference_provider: user switches to anthropic + current.llm_backend 
= Some("anthropic".to_string()); + current.selected_model = None; // cleared because backend changed + + // Simulate step_model_selection: user picks a model + current.selected_model = Some("claude-sonnet-4-5".to_string()); + + // Verify: provider/model changed + assert_eq!(current.llm_backend.as_deref(), Some("anthropic")); + assert_eq!(current.selected_model.as_deref(), Some("claude-sonnet-4-5")); + + // Verify: everything else preserved + assert!(current.channels.http_enabled, "HTTP channel must survive"); + assert_eq!(current.channels.http_port, Some(8080)); + assert!(current.channels.signal_enabled, "Signal must survive"); + assert_eq!( + current.channels.wasm_channels, + vec!["telegram".to_string()], + "WASM channels must survive" + ); + assert!(current.embeddings.enabled, "Embeddings must survive"); + assert_eq!(current.embeddings.provider, "openai"); + assert!(current.heartbeat.enabled, "Heartbeat must survive"); + assert_eq!(current.heartbeat.interval_secs, 900); + assert_eq!( + current.database_backend.as_deref(), + Some("libsql"), + "DB backend must survive" + ); + } + + /// Simulates `ironclaw onboard --channels-only` re-running on a fully + /// configured installation. Only channel settings should change; + /// provider, model, embeddings, heartbeat must survive. 
+ #[test] + fn channels_only_rerun_preserves_unrelated_settings() { + let prior = Settings { + onboard_completed: true, + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + llm_backend: Some("anthropic".to_string()), + selected_model: Some("claude-sonnet-4-5".to_string()), + embeddings: EmbeddingsSettings { + enabled: true, + provider: "nearai".to_string(), + model: "text-embedding-3-small".to_string(), + }, + heartbeat: HeartbeatSettings { + enabled: true, + interval_secs: 1800, + ..Default::default() + }, + channels: ChannelSettings { + http_enabled: false, + wasm_channels: vec!["telegram".to_string()], + ..Default::default() + }, + ..Default::default() + }; + let db_map = prior.to_db_map(); + + // channels_only mode: reconnect_existing_db loads from DB + let mut current = Settings::from_db_map(&db_map); + + // Simulate step_channels: user enables HTTP and adds discord + current.channels.http_enabled = true; + current.channels.http_port = Some(9090); + current.channels.wasm_channels = vec!["telegram".to_string(), "discord".to_string()]; + + // Verify: channels changed + assert!(current.channels.http_enabled); + assert_eq!(current.channels.http_port, Some(9090)); + assert_eq!(current.channels.wasm_channels.len(), 2); + + // Verify: everything else preserved + assert_eq!(current.llm_backend.as_deref(), Some("anthropic")); + assert_eq!(current.selected_model.as_deref(), Some("claude-sonnet-4-5")); + assert!(current.embeddings.enabled); + assert_eq!(current.embeddings.provider, "nearai"); + assert!(current.heartbeat.enabled); + assert_eq!(current.heartbeat.interval_secs, 1800); + } + + /// Simulates quick mode re-run on an installation that previously + /// completed a full setup. Quick mode only touches DB + security + + /// provider + model; channels, embeddings, heartbeat, extensions + /// should survive via the merge_from ordering. 
+ #[test] + fn quick_mode_rerun_preserves_prior_channels_and_heartbeat() { + let prior = Settings { + onboard_completed: true, + database_backend: Some("libsql".to_string()), + libsql_path: Some("/home/user/.ironclaw/ironclaw.db".to_string()), + llm_backend: Some("openai".to_string()), + selected_model: Some("gpt-4o".to_string()), + channels: ChannelSettings { + http_enabled: true, + http_port: Some(8080), + signal_enabled: true, + wasm_channels: vec!["telegram".to_string()], + ..Default::default() + }, + embeddings: EmbeddingsSettings { + enabled: true, + provider: "openai".to_string(), + model: "text-embedding-3-small".to_string(), + }, + heartbeat: HeartbeatSettings { + enabled: true, + interval_secs: 600, + ..Default::default() + }, + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // Quick mode flow: + // 1. auto_setup_database sets DB fields + let step1 = Settings { + database_backend: Some("libsql".to_string()), + libsql_path: Some("/home/user/.ironclaw/ironclaw.db".to_string()), + ..Default::default() + }; + + // 2. try_load_existing_settings → merge DB → merge step1 on top + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // 3. step_inference_provider: user picks anthropic this time + current.llm_backend = Some("anthropic".to_string()); + current.selected_model = None; // cleared because backend changed + + // 4. 
step_model_selection: user picks model + current.selected_model = Some("claude-opus-4-6".to_string()); + + // Verify: provider/model updated + assert_eq!(current.llm_backend.as_deref(), Some("anthropic")); + assert_eq!(current.selected_model.as_deref(), Some("claude-opus-4-6")); + + // Verify: channels, embeddings, heartbeat survived quick mode + assert!( + current.channels.http_enabled, + "HTTP channel must survive quick mode re-run" + ); + assert_eq!(current.channels.http_port, Some(8080)); + assert!( + current.channels.signal_enabled, + "Signal must survive quick mode re-run" + ); + assert_eq!( + current.channels.wasm_channels, + vec!["telegram".to_string()], + "WASM channels must survive quick mode re-run" + ); + assert!( + current.embeddings.enabled, + "Embeddings must survive quick mode re-run" + ); + assert!( + current.heartbeat.enabled, + "Heartbeat must survive quick mode re-run" + ); + assert_eq!(current.heartbeat.interval_secs, 600); + } + + /// Full wizard re-run where user keeps the same provider. The model + /// selection from the prior run should be pre-populated (not reset). + /// + /// Regression: re-running with the same provider should preserve model. 
+ #[test] + fn full_rerun_same_provider_preserves_model_through_merge() { + let prior = Settings { + onboard_completed: true, + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + llm_backend: Some("anthropic".to_string()), + selected_model: Some("claude-sonnet-4-5".to_string()), + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // Step 1: user keeps same DB + let step1 = Settings { + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + ..Default::default() + }; + + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // After merge, prior settings recovered + assert_eq!( + current.llm_backend.as_deref(), + Some("anthropic"), + "Prior provider must be recovered from DB" + ); + assert_eq!( + current.selected_model.as_deref(), + Some("claude-sonnet-4-5"), + "Prior model must be recovered from DB" + ); + + // Step 3: user picks same provider (anthropic) + // set_llm_backend_preserving_model checks if backend changed + let backend_changed = current.llm_backend.as_deref() != Some("anthropic"); + current.llm_backend = Some("anthropic".to_string()); + if backend_changed { + current.selected_model = None; + } + + // Model should NOT be cleared since backend didn't change + assert_eq!( + current.selected_model.as_deref(), + Some("claude-sonnet-4-5"), + "Model must survive when re-selecting same provider" + ); + } + + /// Full wizard re-run where user switches provider. Model should be + /// cleared since the old model is invalid for the new backend. 
+ #[test] + fn full_rerun_different_provider_clears_model_through_merge() { + let prior = Settings { + onboard_completed: true, + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + llm_backend: Some("anthropic".to_string()), + selected_model: Some("claude-sonnet-4-5".to_string()), + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // Step 1 merge + let step1 = Settings { + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + ..Default::default() + }; + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // Step 3: user switches to openai + let backend_changed = current.llm_backend.as_deref() != Some("openai"); + assert!(backend_changed, "switching providers should be detected"); + current.llm_backend = Some("openai".to_string()); + if backend_changed { + current.selected_model = None; + } + + assert_eq!(current.llm_backend.as_deref(), Some("openai")); + assert!( + current.selected_model.is_none(), + "Model must be cleared when switching providers" + ); + } + + /// Simulates incremental save correctness: persist_after_step after + /// Step 3 (provider) should not clobber settings set in Step 2 (security). + /// + /// The wizard persists the full settings object after each step. This + /// test verifies that incremental saves are idempotent for prior steps. 
+ #[test] + fn incremental_persist_does_not_clobber_prior_steps() { + // After steps 1-2, settings has DB + security + let after_step2 = Settings { + database_backend: Some("libsql".to_string()), + secrets_master_key_source: KeySource::Keychain, + ..Default::default() + }; + + // persist_after_step saves to DB + let db_map_after_step2 = after_step2.to_db_map(); + + // Step 3 adds provider + let mut after_step3 = after_step2.clone(); + after_step3.llm_backend = Some("openai".to_string()); + + // persist_after_step saves again — the full settings object + let db_map_after_step3 = after_step3.to_db_map(); + + // Reload from DB after step 3 + let restored = Settings::from_db_map(&db_map_after_step3); + + // Step 2's settings must survive step 3's persist + assert_eq!( + restored.secrets_master_key_source, + KeySource::Keychain, + "Step 2 security setting must survive step 3 persist" + ); + assert_eq!( + restored.database_backend.as_deref(), + Some("libsql"), + "Step 1 DB setting must survive step 3 persist" + ); + assert_eq!( + restored.llm_backend.as_deref(), + Some("openai"), + "Step 3 provider setting must be saved" + ); + + // Also verify that a partial step 2 reload doesn't regress + // (loading the step 2 snapshot and merging with step 3 state) + let from_step2_db = Settings::from_db_map(&db_map_after_step2); + let mut merged = after_step3.clone(); + merged.merge_from(&from_step2_db); + + assert_eq!( + merged.llm_backend.as_deref(), + Some("openai"), + "Step 3 provider must not be clobbered by step 2 snapshot merge" + ); + assert_eq!( + merged.secrets_master_key_source, + KeySource::Keychain, + "Step 2 security must survive merge" + ); + } + + /// Switching database backend should allow fresh connection settings. + /// When user switches from postgres to libsql, the old database_url + /// should not prevent the new libsql_path from being used. 
+ #[test] + fn switching_db_backend_allows_fresh_connection_settings() { + let prior = Settings { + database_backend: Some("postgres".to_string()), + database_url: Some("postgres://host/db".to_string()), + llm_backend: Some("openai".to_string()), + selected_model: Some("gpt-4o".to_string()), + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // User picks libsql this time, wizard clears stale postgres settings + let step1 = Settings { + database_backend: Some("libsql".to_string()), + libsql_path: Some("/home/user/.ironclaw/ironclaw.db".to_string()), + database_url: None, // explicitly not set for libsql + ..Default::default() + }; + + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // libsql chosen + assert_eq!(current.database_backend.as_deref(), Some("libsql")); + assert_eq!( + current.libsql_path.as_deref(), + Some("/home/user/.ironclaw/ironclaw.db") + ); + + // Prior provider/model should survive (unrelated to DB switch) + assert_eq!(current.llm_backend.as_deref(), Some("openai")); + assert_eq!(current.selected_model.as_deref(), Some("gpt-4o")); + + // Note: database_url from prior run persists in merge because + // step1.database_url is None (== default), so merge_from doesn't + // override it. This is expected — the .env writer decides which + // vars to emit based on database_backend. The stale URL is + // harmless because the libsql backend ignores it. + assert_eq!( + current.database_url.as_deref(), + Some("postgres://host/db"), + "stale database_url persists (harmless, ignored by libsql backend)" + ); + } + + /// Regression: merge_from must handle boolean fields correctly. + /// A prior run with heartbeat.enabled=true must not be reset to false + /// when merging with a Settings that has heartbeat.enabled=false (default). 
+ #[test] + fn merge_preserves_true_booleans_when_overlay_has_default_false() { + let prior = Settings { + heartbeat: HeartbeatSettings { + enabled: true, + interval_secs: 600, + ..Default::default() + }, + channels: ChannelSettings { + http_enabled: true, + signal_enabled: true, + ..Default::default() + }, + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // New wizard run only sets DB (everything else is default/false) + let step1 = Settings { + database_backend: Some("libsql".to_string()), + ..Default::default() + }; + + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // true booleans from prior run must survive + assert!( + current.heartbeat.enabled, + "heartbeat.enabled=true must not be reset to false by default overlay" + ); + assert!( + current.channels.http_enabled, + "http_enabled=true must not be reset to false by default overlay" + ); + assert!( + current.channels.signal_enabled, + "signal_enabled=true must not be reset to false by default overlay" + ); + assert_eq!(current.heartbeat.interval_secs, 600); + } + + /// Regression: embeddings settings (provider, model, enabled) must + /// survive a wizard re-run that doesn't touch step 5. 
+ #[test] + fn embeddings_survive_rerun_that_skips_step5() { + let prior = Settings { + onboard_completed: true, + llm_backend: Some("nearai".to_string()), + selected_model: Some("qwen".to_string()), + embeddings: EmbeddingsSettings { + enabled: true, + provider: "nearai".to_string(), + model: "text-embedding-3-large".to_string(), + }, + ..Default::default() + }; + let db_map = prior.to_db_map(); + let from_db = Settings::from_db_map(&db_map); + + // Full re-run: step 1 only sets DB + let step1 = Settings { + database_backend: Some("libsql".to_string()), + ..Default::default() + }; + let mut current = step1.clone(); + current.merge_from(&from_db); + current.merge_from(&step1); + + // Before step 5 (embeddings) runs, check that prior values are present + assert!(current.embeddings.enabled); + assert_eq!(current.embeddings.provider, "nearai"); + assert_eq!(current.embeddings.model, "text-embedding-3-large"); + } } diff --git a/src/setup/README.md b/src/setup/README.md index a1a1d3aa2a..196b910d4f 100644 --- a/src/setup/README.md +++ b/src/setup/README.md @@ -114,6 +114,13 @@ Step 9: Background Tasks (heartbeat) **Goal:** Select backend, establish connection, run migrations. +**Init delegation:** Backend-specific connection logic lives in `src/db/mod.rs` +(`connect_without_migrations()`), not in the wizard. The wizard calls +`test_database_connection()` which delegates to the db module factory. Feature-flag +branching (`#[cfg(feature = ...)]`) is confined to `src/db/mod.rs`. PostgreSQL +validation (version >= 15, pgvector) is handled by `validate_postgres()` in +`src/db/mod.rs`. + **Decision tree:** ``` @@ -121,26 +128,23 @@ Both features compiled? ├─ Yes → DATABASE_BACKEND env var set? 
│ ├─ Yes → use that backend │ └─ No → interactive selection (PostgreSQL vs libSQL) -├─ Only postgres feature → step_database_postgres() -└─ Only libsql feature → step_database_libsql() +├─ Only postgres feature → prompt for DATABASE_URL, test connection +└─ Only libsql feature → prompt for path, test connection ``` -**PostgreSQL path** (`step_database_postgres`): +**PostgreSQL path:** 1. Check `DATABASE_URL` from env or settings -2. Test connection (creates `deadpool_postgres::Pool`) -3. Optionally run refinery migrations -4. Store pool in `self.db_pool` +2. Test connection via `connect_without_migrations()` (validates version, pgvector) +3. Optionally run migrations -**libSQL path** (`step_database_libsql`): +**libSQL path:** 1. Offer local path (default: `~/.ironclaw/ironclaw.db`) 2. Optional Turso cloud sync (URL + auth token) -3. Test connection (creates `LibSqlBackend`) +3. Test connection via `connect_without_migrations()` 4. Always run migrations (idempotent CREATE IF NOT EXISTS) -5. Store backend in `self.db_backend` -**Invariant:** After Step 1, exactly one of `self.db_pool` or -`self.db_backend` is `Some`. This is required for settings persistence -in `save_and_summarize()`. +**Invariant:** After Step 1, `self.db` is `Some(Arc)`. +This is required for settings persistence in `save_and_summarize()`. --- @@ -338,7 +342,7 @@ key first, then falls back to the standard env var. 1. Check `self.secrets_crypto` (set in Step 2) → use if available 2. Else try `SECRETS_MASTER_KEY` env var 3. Else try `get_master_key()` from keychain (only in `channels_only` mode) -4. Create backend-appropriate secrets store (respects selected database backend) +4. 
Create secrets store using `self.db` (`Arc`) --- diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs index f8c695f156..9437d8279b 100644 --- a/src/setup/wizard.rs +++ b/src/setup/wizard.rs @@ -14,8 +14,6 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -#[cfg(feature = "postgres")] -use deadpool_postgres::Config as PoolConfig; use secrecy::{ExposeSecret, SecretString}; use crate::bootstrap::ironclaw_base_dir; @@ -23,8 +21,12 @@ use crate::channels::wasm::{ ChannelCapabilitiesFile, available_channel_names, install_bundled_channel, }; use crate::config::OAUTH_PLACEHOLDER; +use crate::llm::models::{ + build_nearai_model_fetch_config, fetch_anthropic_models, fetch_ollama_models, + fetch_openai_compatible_models, fetch_openai_models, +}; use crate::llm::{SessionConfig, SessionManager}; -use crate::secrets::{SecretsCrypto, SecretsStore}; +use crate::secrets::SecretsCrypto; use crate::settings::{KeySource, Settings}; use crate::setup::channels::{ SecretsContext, setup_http, setup_signal, setup_tunnel, setup_wasm_channel, @@ -85,12 +87,10 @@ pub struct SetupWizard { config: SetupConfig, settings: Settings, session_manager: Option>, - /// Database pool (created during setup, postgres only). - #[cfg(feature = "postgres")] - db_pool: Option, - /// libSQL backend (created during setup, libsql only). - #[cfg(feature = "libsql")] - db_backend: Option, + /// Backend-agnostic database trait object (created during setup). + db: Option>, + /// Backend-specific handles for secrets store and other satellite consumers. + db_handles: Option, /// Secrets crypto (created during setup). secrets_crypto: Option>, /// Cached API key from provider setup (used by model fetcher without env mutation). 
@@ -104,10 +104,8 @@ impl SetupWizard { config: SetupConfig::default(), settings: Settings::default(), session_manager: None, - #[cfg(feature = "postgres")] - db_pool: None, - #[cfg(feature = "libsql")] - db_backend: None, + db: None, + db_handles: None, secrets_crypto: None, llm_api_key: None, } @@ -119,10 +117,8 @@ impl SetupWizard { config, settings: Settings::default(), session_manager: None, - #[cfg(feature = "postgres")] - db_pool: None, - #[cfg(feature = "libsql")] - db_backend: None, + db: None, + db_handles: None, secrets_crypto: None, llm_api_key: None, } @@ -256,115 +252,79 @@ impl SetupWizard { /// database connection and the wizard's `self.settings` reflects the /// previously saved configuration. async fn reconnect_existing_db(&mut self) -> Result<(), SetupError> { - // Determine backend from env (set by bootstrap .env loaded in main). - let backend = std::env::var("DATABASE_BACKEND").unwrap_or_else(|_| "postgres".to_string()); - - // Try libsql first if that's the configured backend. - #[cfg(feature = "libsql")] - if backend == "libsql" || backend == "turso" || backend == "sqlite" { - return self.reconnect_libsql().await; - } - - // Try postgres (either explicitly configured or as default). - #[cfg(feature = "postgres")] - { - let _ = &backend; - return self.reconnect_postgres().await; - } + use crate::config::DatabaseConfig; - #[allow(unreachable_code)] - Err(SetupError::Database( - "No database configured. Run full setup first (ironclaw onboard).".to_string(), - )) - } - - /// Reconnect to an existing PostgreSQL database and load settings. - #[cfg(feature = "postgres")] - async fn reconnect_postgres(&mut self) -> Result<(), SetupError> { - let url = std::env::var("DATABASE_URL").map_err(|_| { - SetupError::Database( - "DATABASE_URL not set. Run full setup first (ironclaw onboard).".to_string(), - ) + let db_config = DatabaseConfig::resolve().map_err(|e| { + SetupError::Database(format!( + "Cannot resolve database config. 
Run full setup first (ironclaw onboard): {}", + e + )) })?; - self.test_database_connection_postgres(&url).await?; - self.settings.database_backend = Some("postgres".to_string()); - self.settings.database_url = Some(url.clone()); + let backend_name = db_config.backend.to_string(); + let (db, handles) = crate::db::connect_with_handles(&db_config) + .await + .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))?; - // Load existing settings from DB, then restore connection fields that - // may not be persisted in the settings map. - if let Some(ref pool) = self.db_pool { - let store = crate::history::Store::from_pool(pool.clone()); - if let Ok(map) = store.get_all_settings("default").await { - self.settings = Settings::from_db_map(&map); - self.settings.database_backend = Some("postgres".to_string()); - self.settings.database_url = Some(url); - } + // Load existing settings from DB + if let Ok(map) = db.get_all_settings("default").await { + self.settings = Settings::from_db_map(&map); } - Ok(()) - } - - /// Reconnect to an existing libSQL database and load settings. - #[cfg(feature = "libsql")] - async fn reconnect_libsql(&mut self) -> Result<(), SetupError> { - let path = std::env::var("LIBSQL_PATH").unwrap_or_else(|_| { - crate::config::default_libsql_path() - .to_string_lossy() - .to_string() - }); - let turso_url = std::env::var("LIBSQL_URL").ok(); - let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - - self.test_database_connection_libsql(&path, turso_url.as_deref(), turso_token.as_deref()) - .await?; - - self.settings.database_backend = Some("libsql".to_string()); - self.settings.libsql_path = Some(path.clone()); - if let Some(ref url) = turso_url { - self.settings.libsql_url = Some(url.clone()); - } - - // Load existing settings from DB, then restore connection fields that - // may not be persisted in the settings map. 
- if let Some(ref db) = self.db_backend { - use crate::db::SettingsStore as _; - if let Ok(map) = db.get_all_settings("default").await { - self.settings = Settings::from_db_map(&map); - self.settings.database_backend = Some("libsql".to_string()); - self.settings.libsql_path = Some(path); - if let Some(url) = turso_url { - self.settings.libsql_url = Some(url); - } - } + // Restore connection fields that may not be persisted in the settings map + self.settings.database_backend = Some(backend_name); + if let Ok(url) = std::env::var("DATABASE_URL") { + self.settings.database_url = Some(url); + } + if let Ok(path) = std::env::var("LIBSQL_PATH") { + self.settings.libsql_path = Some(path); + } else if db_config.libsql_path.is_some() { + self.settings.libsql_path = db_config + .libsql_path + .as_ref() + .map(|p| p.to_string_lossy().to_string()); } + if let Ok(url) = std::env::var("LIBSQL_URL") { + self.settings.libsql_url = Some(url); + } + + self.db = Some(db); + self.db_handles = Some(handles); Ok(()) } /// Step 1: Database connection. + /// + /// Determines the backend at runtime (env var, interactive selection, or + /// compile-time default) and runs the appropriate configuration flow. async fn step_database(&mut self) -> Result<(), SetupError> { - // When both features are compiled, let the user choose. - // If DATABASE_BACKEND is already set in the environment, respect it. 
- #[cfg(all(feature = "postgres", feature = "libsql"))] - { - // Check if a backend is already pinned via env var - let env_backend = std::env::var("DATABASE_BACKEND").ok(); + use crate::config::{DatabaseBackend, DatabaseConfig}; - if let Some(ref backend) = env_backend { - if backend == "libsql" || backend == "turso" || backend == "sqlite" { - return self.step_database_libsql().await; - } - if backend != "postgres" && backend != "postgresql" { + const POSTGRES_AVAILABLE: bool = cfg!(feature = "postgres"); + const LIBSQL_AVAILABLE: bool = cfg!(feature = "libsql"); + + // Determine backend from env var, interactive selection, or default. + let env_backend = std::env::var("DATABASE_BACKEND").ok(); + + let backend = if let Some(ref raw) = env_backend { + match raw.parse::() { + Ok(b) => b, + Err(_) => { + let fallback = if POSTGRES_AVAILABLE { + DatabaseBackend::Postgres + } else { + DatabaseBackend::LibSql + }; print_info(&format!( - "Unknown DATABASE_BACKEND '{}', defaulting to PostgreSQL", - backend + "Unknown DATABASE_BACKEND '{}', defaulting to {}", + raw, fallback )); + fallback } - return self.step_database_postgres().await; } - - // Interactive selection + } else if POSTGRES_AVAILABLE && LIBSQL_AVAILABLE { + // Both features compiled — offer interactive selection. 
let pre_selected = self.settings.database_backend.as_deref().map(|b| match b { "libsql" | "turso" | "sqlite" => 1, _ => 0, @@ -390,88 +350,82 @@ impl SetupWizard { self.settings.libsql_url = None; } - match choice { - 1 => return self.step_database_libsql().await, - _ => return self.step_database_postgres().await, + if choice == 1 { + DatabaseBackend::LibSql + } else { + DatabaseBackend::Postgres } - } - - #[cfg(all(feature = "postgres", not(feature = "libsql")))] - { - return self.step_database_postgres().await; - } - - #[cfg(all(feature = "libsql", not(feature = "postgres")))] - { - return self.step_database_libsql().await; - } - } + } else if LIBSQL_AVAILABLE { + DatabaseBackend::LibSql + } else { + // Only postgres (or neither, but that won't compile anyway). + DatabaseBackend::Postgres + }; - /// Step 1 (postgres): Database connection via PostgreSQL URL. - #[cfg(feature = "postgres")] - async fn step_database_postgres(&mut self) -> Result<(), SetupError> { - self.settings.database_backend = Some("postgres".to_string()); + // --- Postgres flow --- + if backend == DatabaseBackend::Postgres { + self.settings.database_backend = Some("postgres".to_string()); - let existing_url = std::env::var("DATABASE_URL") - .ok() - .or_else(|| self.settings.database_url.clone()); + let existing_url = std::env::var("DATABASE_URL") + .ok() + .or_else(|| self.settings.database_url.clone()); - if let Some(ref url) = existing_url { - let display_url = mask_password_in_url(url); - print_info(&format!("Existing database URL: {}", display_url)); + if let Some(ref url) = existing_url { + let display_url = mask_password_in_url(url); + print_info(&format!("Existing database URL: {}", display_url)); - if confirm("Use this database?", true).map_err(SetupError::Io)? 
{ - if let Err(e) = self.test_database_connection_postgres(url).await { - print_error(&format!("Connection failed: {}", e)); - print_info("Let's configure a new database URL."); - } else { - print_success("Database connection successful"); - self.settings.database_url = Some(url.clone()); - return Ok(()); + if confirm("Use this database?", true).map_err(SetupError::Io)? { + let config = DatabaseConfig::from_postgres_url(url, 5); + if let Err(e) = self.test_database_connection(&config).await { + print_error(&format!("Connection failed: {}", e)); + print_info("Let's configure a new database URL."); + } else { + print_success("Database connection successful"); + self.settings.database_url = Some(url.clone()); + return Ok(()); + } } } - } - println!(); - print_info("Enter your PostgreSQL connection URL."); - print_info("Format: postgres://user:password@host:port/database"); - println!(); + println!(); + print_info("Enter your PostgreSQL connection URL."); + print_info("Format: postgres://user:password@host:port/database"); + println!(); - loop { - let url = input("Database URL").map_err(SetupError::Io)?; + loop { + let url = input("Database URL").map_err(SetupError::Io)?; - if url.is_empty() { - print_error("Database URL is required."); - continue; - } + if url.is_empty() { + print_error("Database URL is required."); + continue; + } - print_info("Testing connection..."); - match self.test_database_connection_postgres(&url).await { - Ok(()) => { - print_success("Database connection successful"); + print_info("Testing connection..."); + let config = DatabaseConfig::from_postgres_url(&url, 5); + match self.test_database_connection(&config).await { + Ok(()) => { + print_success("Database connection successful"); - if confirm("Run database migrations?", true).map_err(SetupError::Io)? { - self.run_migrations_postgres().await?; - } + if confirm("Run database migrations?", true).map_err(SetupError::Io)? 
{ + self.run_migrations().await?; + } - self.settings.database_url = Some(url); - return Ok(()); - } - Err(e) => { - print_error(&format!("Connection failed: {}", e)); - if !confirm("Try again?", true).map_err(SetupError::Io)? { - return Err(SetupError::Database( - "Database connection failed".to_string(), - )); + self.settings.database_url = Some(url); + return Ok(()); + } + Err(e) => { + print_error(&format!("Connection failed: {}", e)); + if !confirm("Try again?", true).map_err(SetupError::Io)? { + return Err(SetupError::Database( + "Database connection failed".to_string(), + )); + } } } } } - } - /// Step 1 (libsql): Database connection via local file or Turso remote replica. - #[cfg(feature = "libsql")] - async fn step_database_libsql(&mut self) -> Result<(), SetupError> { + // --- libSQL flow --- self.settings.database_backend = Some("libsql".to_string()); let default_path = crate::config::default_libsql_path(); @@ -490,14 +444,12 @@ impl SetupWizard { .or_else(|| self.settings.libsql_url.clone()); let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - match self - .test_database_connection_libsql( - path, - turso_url.as_deref(), - turso_token.as_deref(), - ) - .await - { + let config = DatabaseConfig::from_libsql_path( + path, + turso_url.as_deref(), + turso_token.as_deref(), + ); + match self.test_database_connection(&config).await { Ok(()) => { print_success("Database connection successful"); self.settings.libsql_path = Some(path.clone()); @@ -556,15 +508,17 @@ impl SetupWizard { }; print_info("Testing connection..."); - match self - .test_database_connection_libsql(&db_path, turso_url.as_deref(), turso_token.as_deref()) - .await - { + let config = DatabaseConfig::from_libsql_path( + &db_path, + turso_url.as_deref(), + turso_token.as_deref(), + ); + match self.test_database_connection(&config).await { Ok(()) => { print_success("Database connection successful"); // Always run migrations for libsql (they're idempotent) - 
self.run_migrations_libsql().await?; + self.run_migrations().await?; self.settings.libsql_path = Some(db_path); if let Some(url) = turso_url { @@ -576,155 +530,39 @@ impl SetupWizard { } } - /// Test PostgreSQL connection and store the pool. + /// Test database connection using the db module factory. /// - /// After connecting, validates: - /// 1. PostgreSQL version >= 15 (required for pgvector compatibility) - /// 2. pgvector extension is available (required for embeddings/vector search) - #[cfg(feature = "postgres")] - async fn test_database_connection_postgres(&mut self, url: &str) -> Result<(), SetupError> { - let mut cfg = PoolConfig::new(); - cfg.url = Some(url.to_string()); - cfg.pool = Some(deadpool_postgres::PoolConfig { - max_size: 5, - ..Default::default() - }); - - let pool = crate::db::tls::create_pool(&cfg, crate::config::SslMode::from_env()) - .map_err(|e| SetupError::Database(format!("Failed to create pool: {}", e)))?; - - let client = pool - .get() - .await - .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))?; - - // Check PostgreSQL server version (need 15+ for pgvector) - let version_row = client - .query_one("SHOW server_version", &[]) - .await - .map_err(|e| SetupError::Database(format!("Failed to query server version: {}", e)))?; - let version_str: &str = version_row.get(0); - let major_version = version_str - .split('.') - .next() - .and_then(|v| v.parse::().ok()) - .unwrap_or(0); - - const MIN_PG_MAJOR_VERSION: u32 = 15; - - if major_version < MIN_PG_MAJOR_VERSION { - return Err(SetupError::Database(format!( - "PostgreSQL {} detected. 
IronClaw requires PostgreSQL {} or later for pgvector support.\n\ - Upgrade: https://www.postgresql.org/download/", - version_str, MIN_PG_MAJOR_VERSION - ))); - } - - // Check if pgvector extension is available - let pgvector_row = client - .query_opt( - "SELECT 1 FROM pg_available_extensions WHERE name = 'vector'", - &[], - ) - .await - .map_err(|e| { - SetupError::Database(format!("Failed to check pgvector availability: {}", e)) - })?; - - if pgvector_row.is_none() { - return Err(SetupError::Database(format!( - "pgvector extension not found on your PostgreSQL server.\n\n\ - Install it:\n \ - macOS: brew install pgvector\n \ - Ubuntu: apt install postgresql-{0}-pgvector\n \ - Docker: use the pgvector/pgvector:pg{0} image\n \ - Source: https://github.com/pgvector/pgvector#installation\n\n\ - Then restart PostgreSQL and re-run: ironclaw onboard", - major_version - ))); - } - - self.db_pool = Some(pool); - Ok(()) - } - - /// Test libSQL connection and store the backend. - #[cfg(feature = "libsql")] - async fn test_database_connection_libsql( + /// Connects without running migrations and validates PostgreSQL + /// prerequisites (version, pgvector) when using the postgres backend. + async fn test_database_connection( &mut self, - path: &str, - turso_url: Option<&str>, - turso_token: Option<&str>, + config: &crate::config::DatabaseConfig, ) -> Result<(), SetupError> { - use crate::db::libsql::LibSqlBackend; - use std::path::Path; - - let db_path = Path::new(path); - - let backend = if let (Some(url), Some(token)) = (turso_url, turso_token) { - LibSqlBackend::new_remote_replica(db_path, url, token) - .await - .map_err(|e| SetupError::Database(format!("Failed to connect: {}", e)))? - } else { - LibSqlBackend::new_local(db_path) - .await - .map_err(|e| SetupError::Database(format!("Failed to open database: {}", e)))? - }; - - self.db_backend = Some(backend); - Ok(()) - } - - /// Run PostgreSQL migrations. 
- #[cfg(feature = "postgres")] - async fn run_migrations_postgres(&self) -> Result<(), SetupError> { - if let Some(ref pool) = self.db_pool { - use refinery::embed_migrations; - embed_migrations!("migrations"); - - if !self.config.quick { - print_info("Running migrations..."); - } - tracing::debug!("Running PostgreSQL migrations..."); - - let mut client = pool - .get() - .await - .map_err(|e| SetupError::Database(format!("Pool error: {}", e)))?; - - migrations::runner() - .run_async(&mut **client) - .await - .map_err(|e| SetupError::Database(format!("Migration failed: {}", e)))?; + let (db, handles) = crate::db::connect_without_migrations(config) + .await + .map_err(|e| SetupError::Database(e.to_string()))?; - if !self.config.quick { - print_success("Migrations applied"); - } - tracing::debug!("PostgreSQL migrations applied"); - } + self.db = Some(db); + self.db_handles = Some(handles); Ok(()) } - /// Run libSQL migrations. - #[cfg(feature = "libsql")] - async fn run_migrations_libsql(&self) -> Result<(), SetupError> { - if let Some(ref backend) = self.db_backend { - use crate::db::Database; - + /// Run database migrations on the current connection. + async fn run_migrations(&self) -> Result<(), SetupError> { + if let Some(ref db) = self.db { if !self.config.quick { print_info("Running migrations..."); } - tracing::debug!("Running libSQL migrations..."); + tracing::debug!("Running database migrations..."); - backend - .run_migrations() + db.run_migrations() .await .map_err(|e| SetupError::Database(format!("Migration failed: {}", e)))?; if !self.config.quick { print_success("Migrations applied"); } - tracing::debug!("libSQL migrations applied"); + tracing::debug!("Database migrations applied"); } Ok(()) } @@ -741,20 +579,19 @@ impl SetupWizard { return Ok(()); } - // Try to retrieve existing key from keychain. 
We use get_master_key() - // instead of has_master_key() so we can cache the key bytes and build - // SecretsCrypto eagerly, avoiding redundant keychain accesses later - // (each access triggers macOS system dialogs). + // Try to retrieve existing key from keychain via resolve_master_key + // (checks env var first, then keychain). We skip the env var case + // above, so this will only find a keychain key here. print_info("Checking OS keychain for existing master key..."); if let Ok(keychain_key_bytes) = crate::secrets::keychain::get_master_key().await { let key_hex: String = keychain_key_bytes .iter() .map(|b| format!("{:02x}", b)) .collect(); - self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex)) + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); + ); print_info("Existing master key found in OS keychain."); if confirm("Use existing keychain key?", true).map_err(SetupError::Io)? { @@ -793,12 +630,11 @@ impl SetupWizard { SetupError::Config(format!("Failed to store in keychain: {}", e)) })?; - // Also create crypto instance let key_hex: String = key.iter().map(|b| format!("{:02x}", b)).collect(); - self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex)) + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); + ); self.settings.secrets_master_key_source = KeySource::Keychain; print_success("Master key generated and stored in OS keychain"); @@ -809,10 +645,10 @@ impl SetupWizard { // Initialize crypto so subsequent wizard steps (channel setup, // API key storage) can encrypt secrets immediately. 
- self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex.clone())) + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); + ); // Make visible to optional_env() for any subsequent config resolution. crate::config::inject_single_var("SECRETS_MASTER_KEY", &key_hex); @@ -845,16 +681,22 @@ impl SetupWizard { /// standard path. Falls back to the interactive `step_database()` only when /// just the postgres feature is compiled (can't auto-default postgres). async fn auto_setup_database(&mut self) -> Result<(), SetupError> { - // If DATABASE_URL or LIBSQL_PATH already set, respect existing config - #[cfg(feature = "postgres")] + use crate::config::{DatabaseBackend, DatabaseConfig}; + + const POSTGRES_AVAILABLE: bool = cfg!(feature = "postgres"); + const LIBSQL_AVAILABLE: bool = cfg!(feature = "libsql"); + let env_backend = std::env::var("DATABASE_BACKEND").ok(); - #[cfg(feature = "postgres")] + // If DATABASE_BACKEND=postgres and DATABASE_URL exists: connect+migrate if let Some(ref backend) = env_backend - && (backend == "postgres" || backend == "postgresql") + && let Ok(DatabaseBackend::Postgres) = backend.parse::() { if let Ok(url) = std::env::var("DATABASE_URL") { print_info("Using existing PostgreSQL configuration"); + let config = DatabaseConfig::from_postgres_url(&url, 5); + self.test_database_connection(&config).await?; + self.run_migrations().await?; self.settings.database_backend = Some("postgres".to_string()); self.settings.database_url = Some(url); return Ok(()); @@ -863,17 +705,23 @@ impl SetupWizard { return self.step_database().await; } - #[cfg(feature = "postgres")] - if let Ok(url) = std::env::var("DATABASE_URL") { + // If DATABASE_URL exists (no explicit backend): connect+migrate as postgres, + // but only when the postgres feature is actually compiled in. 
+ if POSTGRES_AVAILABLE + && env_backend.is_none() + && let Ok(url) = std::env::var("DATABASE_URL") + { print_info("Using existing PostgreSQL configuration"); + let config = DatabaseConfig::from_postgres_url(&url, 5); + self.test_database_connection(&config).await?; + self.run_migrations().await?; self.settings.database_backend = Some("postgres".to_string()); self.settings.database_url = Some(url); return Ok(()); } - // Auto-default to libsql if the feature is compiled - #[cfg(feature = "libsql")] - { + // Auto-default to libsql if available + if LIBSQL_AVAILABLE { self.settings.database_backend = Some("libsql".to_string()); let existing_path = std::env::var("LIBSQL_PATH") @@ -889,14 +737,13 @@ impl SetupWizard { let turso_url = std::env::var("LIBSQL_URL").ok(); let turso_token = std::env::var("LIBSQL_AUTH_TOKEN").ok(); - self.test_database_connection_libsql( + let config = DatabaseConfig::from_libsql_path( &db_path, turso_url.as_deref(), turso_token.as_deref(), - ) - .await?; - - self.run_migrations_libsql().await?; + ); + self.test_database_connection(&config).await?; + self.run_migrations().await?; self.settings.libsql_path = Some(db_path.clone()); if let Some(url) = turso_url { @@ -908,10 +755,7 @@ impl SetupWizard { } // Only postgres feature compiled — can't auto-default, use interactive - #[allow(unreachable_code)] - { - self.step_database().await - } + self.step_database().await } /// Auto-setup security with zero prompts (quick mode). @@ -920,26 +764,23 @@ impl SetupWizard { /// key if available, otherwise generates and stores one automatically /// (keychain on macOS, env var fallback). 
async fn auto_setup_security(&mut self) -> Result<(), SetupError> { - // Check env var first - if std::env::var("SECRETS_MASTER_KEY").is_ok() { - self.settings.secrets_master_key_source = KeySource::Env; - print_success("Security configured (env var)"); - return Ok(()); - } - - // Try existing keychain key (no prompts — get_master_key may show - // OS dialogs on macOS, but that's unavoidable for keychain access) - if let Ok(keychain_key_bytes) = crate::secrets::keychain::get_master_key().await { - let key_hex: String = keychain_key_bytes - .iter() - .map(|b| format!("{:02x}", b)) - .collect(); - self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex)) + // Try resolving an existing key from env var or keychain + if let Some(key_hex) = crate::secrets::resolve_master_key().await { + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); - self.settings.secrets_master_key_source = KeySource::Keychain; - print_success("Security configured (keychain)"); + ); + // Determine source: env var or keychain (filter empty to match resolve_master_key) + let (source, label) = if std::env::var("SECRETS_MASTER_KEY") + .ok() + .is_some_and(|v| !v.is_empty()) + { + (KeySource::Env, "env var") + } else { + (KeySource::Keychain, "keychain") + }; + self.settings.secrets_master_key_source = source; + print_success(&format!("Security configured ({})", label)); return Ok(()); } @@ -951,10 +792,10 @@ impl SetupWizard { .is_ok() { let key_hex: String = key.iter().map(|b| format!("{:02x}", b)).collect(); - self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex)) + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); + ); self.settings.secrets_master_key_source = KeySource::Keychain; print_success("Master key stored in OS keychain"); return Ok(()); @@ -962,10 +803,10 @@ impl SetupWizard { 
// Keychain unavailable — fall back to env var mode let key_hex = crate::secrets::keychain::generate_master_key_hex(); - self.secrets_crypto = Some(Arc::new( - SecretsCrypto::new(SecretString::from(key_hex.clone())) + self.secrets_crypto = Some( + crate::secrets::crypto_from_hex(&key_hex) .map_err(|e| SetupError::Config(e.to_string()))?, - )); + ); crate::config::inject_single_var("SECRETS_MASTER_KEY", &key_hex); self.settings.secrets_master_key_hex = Some(key_hex); self.settings.secrets_master_key_source = KeySource::Env; @@ -1836,74 +1677,27 @@ impl SetupWizard { /// Initialize secrets context for channel setup. async fn init_secrets_context(&mut self) -> Result { - // Get crypto (should be set from step 2, or load from keychain/env) + // Get crypto (should be set from step 2, or resolve from keychain/env) let crypto = if let Some(ref c) = self.secrets_crypto { Arc::clone(c) } else { - // Try to load master key from keychain or env - let key = if let Ok(env_key) = std::env::var("SECRETS_MASTER_KEY") { - env_key - } else if let Ok(keychain_key) = crate::secrets::keychain::get_master_key().await { - keychain_key.iter().map(|b| format!("{:02x}", b)).collect() - } else { - return Err(SetupError::Config( + let key_hex = crate::secrets::resolve_master_key().await.ok_or_else(|| { + SetupError::Config( "Secrets not configured. Run full setup or set SECRETS_MASTER_KEY.".to_string(), - )); - }; + ) + })?; - let crypto = Arc::new( - SecretsCrypto::new(SecretString::from(key)) - .map_err(|e| SetupError::Config(e.to_string()))?, - ); + let crypto = crate::secrets::crypto_from_hex(&key_hex) + .map_err(|e| SetupError::Config(e.to_string()))?; self.secrets_crypto = Some(Arc::clone(&crypto)); crypto }; - // Create backend-appropriate secrets store. - // Use runtime dispatch based on the user's selected backend. - // Default to whichever backend is compiled in. When only libsql is - // available, we must not default to "postgres" or we'd skip store creation. 
- let default_backend = { - #[cfg(feature = "postgres")] - { - "postgres" - } - #[cfg(not(feature = "postgres"))] - { - "libsql" - } - }; - let selected_backend = self - .settings - .database_backend - .as_deref() - .unwrap_or(default_backend); - - match selected_backend { - #[cfg(feature = "libsql")] - "libsql" | "turso" | "sqlite" => { - if let Some(store) = self.create_libsql_secrets_store(&crypto)? { - return Ok(SecretsContext::from_store(store, "default")); - } - // Fallback to postgres if libsql store creation returned None - #[cfg(feature = "postgres")] - if let Some(store) = self.create_postgres_secrets_store(&crypto).await? { - return Ok(SecretsContext::from_store(store, "default")); - } - } - #[cfg(feature = "postgres")] - _ => { - if let Some(store) = self.create_postgres_secrets_store(&crypto).await? { - return Ok(SecretsContext::from_store(store, "default")); - } - // Fallback to libsql if postgres store creation returned None - #[cfg(feature = "libsql")] - if let Some(store) = self.create_libsql_secrets_store(&crypto)? { - return Ok(SecretsContext::from_store(store, "default")); - } - } - #[cfg(not(feature = "postgres"))] - _ => {} + // Create secrets store from existing database handles + if let Some(ref handles) = self.db_handles + && let Some(store) = crate::secrets::create_secrets_store(Arc::clone(&crypto), handles) + { + return Ok(SecretsContext::from_store(store, "default")); } Err(SetupError::Config( @@ -1911,62 +1705,6 @@ impl SetupWizard { )) } - /// Create a PostgreSQL secrets store from the current pool. 
- #[cfg(feature = "postgres")] - async fn create_postgres_secrets_store( - &mut self, - crypto: &Arc, - ) -> Result>, SetupError> { - let pool = if let Some(ref p) = self.db_pool { - p.clone() - } else { - // Fall back to creating one from settings/env - let url = self - .settings - .database_url - .clone() - .or_else(|| std::env::var("DATABASE_URL").ok()); - - if let Some(url) = url { - self.test_database_connection_postgres(&url).await?; - self.run_migrations_postgres().await?; - match self.db_pool.clone() { - Some(pool) => pool, - None => { - return Err(SetupError::Database( - "Database pool not initialized after connection test".to_string(), - )); - } - } - } else { - return Ok(None); - } - }; - - let store: Arc = Arc::new(crate::secrets::PostgresSecretsStore::new( - pool, - Arc::clone(crypto), - )); - Ok(Some(store)) - } - - /// Create a libSQL secrets store from the current backend. - #[cfg(feature = "libsql")] - fn create_libsql_secrets_store( - &self, - crypto: &Arc, - ) -> Result>, SetupError> { - if let Some(ref backend) = self.db_backend { - let store: Arc = Arc::new(crate::secrets::LibSqlSecretsStore::new( - backend.shared_db(), - Arc::clone(crypto), - )); - Ok(Some(store)) - } else { - Ok(None) - } - } - /// Step 6: Channel configuration. async fn step_channels(&mut self) -> Result<(), SetupError> { // First, configure tunnel (shared across all channels that need webhooks) @@ -2484,45 +2222,15 @@ impl SetupWizard { /// connection is available yet (e.g., before Step 1 completes). 
async fn persist_settings(&self) -> Result { let db_map = self.settings.to_db_map(); - let saved = false; - - #[cfg(feature = "postgres")] - let saved = if !saved { - if let Some(ref pool) = self.db_pool { - let store = crate::history::Store::from_pool(pool.clone()); - store - .set_all_settings("default", &db_map) - .await - .map_err(|e| { - SetupError::Database(format!("Failed to save settings to database: {}", e)) - })?; - true - } else { - false - } - } else { - saved - }; - #[cfg(feature = "libsql")] - let saved = if !saved { - if let Some(ref backend) = self.db_backend { - use crate::db::SettingsStore as _; - backend - .set_all_settings("default", &db_map) - .await - .map_err(|e| { - SetupError::Database(format!("Failed to save settings to database: {}", e)) - })?; - true - } else { - false - } + if let Some(ref db) = self.db { + db.set_all_settings("default", &db_map).await.map_err(|e| { + SetupError::Database(format!("Failed to save settings to database: {}", e)) + })?; + Ok(true) } else { - saved - }; - - Ok(saved) + Ok(false) + } } /// Write bootstrap environment variables to `~/.ironclaw/.env`. 
@@ -2698,28 +2406,12 @@ impl SetupWizard { Err(_) => return, }; - #[cfg(feature = "postgres")] - if let Some(ref pool) = self.db_pool { - let store = crate::history::Store::from_pool(pool.clone()); - if let Err(e) = store - .set_setting("default", "nearai.session_token", &value) - .await - { - tracing::debug!("Could not persist session token to postgres: {}", e); - } else { - tracing::debug!("Session token persisted to database"); - return; - } - } - - #[cfg(feature = "libsql")] - if let Some(ref backend) = self.db_backend { - use crate::db::SettingsStore as _; - if let Err(e) = backend + if let Some(ref db) = self.db { + if let Err(e) = db .set_setting("default", "nearai.session_token", &value) .await { - tracing::debug!("Could not persist session token to libsql: {}", e); + tracing::debug!("Could not persist session token to database: {}", e); } else { tracing::debug!("Session token persisted to database"); } @@ -2756,58 +2448,19 @@ impl SetupWizard { /// prefers the `other` argument's non-default values. Without this, /// stale DB values would overwrite fresh user choices. 
async fn try_load_existing_settings(&mut self) { - let loaded = false; - - #[cfg(feature = "postgres")] - let loaded = if !loaded { - if let Some(ref pool) = self.db_pool { - let store = crate::history::Store::from_pool(pool.clone()); - match store.get_all_settings("default").await { - Ok(db_map) if !db_map.is_empty() => { - let existing = Settings::from_db_map(&db_map); - self.settings.merge_from(&existing); - tracing::info!("Loaded {} existing settings from database", db_map.len()); - true - } - Ok(_) => false, - Err(e) => { - tracing::debug!("Could not load existing settings: {}", e); - false - } + if let Some(ref db) = self.db { + match db.get_all_settings("default").await { + Ok(db_map) if !db_map.is_empty() => { + let existing = Settings::from_db_map(&db_map); + self.settings.merge_from(&existing); + tracing::info!("Loaded {} existing settings from database", db_map.len()); } - } else { - false - } - } else { - loaded - }; - - #[cfg(feature = "libsql")] - let loaded = if !loaded { - if let Some(ref backend) = self.db_backend { - use crate::db::SettingsStore as _; - match backend.get_all_settings("default").await { - Ok(db_map) if !db_map.is_empty() => { - let existing = Settings::from_db_map(&db_map); - self.settings.merge_from(&existing); - tracing::info!("Loaded {} existing settings from database", db_map.len()); - true - } - Ok(_) => false, - Err(e) => { - tracing::debug!("Could not load existing settings: {}", e); - false - } + Ok(_) => {} + Err(e) => { + tracing::debug!("Could not load existing settings: {}", e); } - } else { - false } - } else { - loaded - }; - - // Suppress unused variable warning when only one backend is compiled. - let _ = loaded; + } } /// Save settings to the database and `~/.ironclaw/.env`, then print summary. @@ -2957,7 +2610,6 @@ impl Default for SetupWizard { } /// Mask password in a database URL for display. 
-#[cfg(feature = "postgres")] fn mask_password_in_url(url: &str) -> String { // URL format: scheme://user:password@host/database // Find "://" to locate start of credentials @@ -2986,331 +2638,6 @@ fn mask_password_in_url(url: &str) -> String { format!("{}{}:****{}", scheme, username, after_at) } -/// Fetch models from the Anthropic API. -/// -/// Returns `(model_id, display_label)` pairs. Falls back to static defaults on error. -async fn fetch_anthropic_models(cached_key: Option<&str>) -> Vec<(String, String)> { - let static_defaults = vec![ - ( - "claude-opus-4-6".into(), - "Claude Opus 4.6 (latest flagship)".into(), - ), - ("claude-sonnet-4-6".into(), "Claude Sonnet 4.6".into()), - ("claude-opus-4-5".into(), "Claude Opus 4.5".into()), - ("claude-sonnet-4-5".into(), "Claude Sonnet 4.5".into()), - ("claude-haiku-4-5".into(), "Claude Haiku 4.5 (fast)".into()), - ]; - - let api_key = cached_key - .map(String::from) - .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok()) - .filter(|k| !k.is_empty() && k != crate::config::OAUTH_PLACEHOLDER); - - // Fall back to OAuth token if no API key - let oauth_token = if api_key.is_none() { - crate::config::helpers::optional_env("ANTHROPIC_OAUTH_TOKEN") - .ok() - .flatten() - .filter(|t| !t.is_empty()) - } else { - None - }; - - let (key_or_token, is_oauth) = match (api_key, oauth_token) { - (Some(k), _) => (k, false), - (None, Some(t)) => (t, true), - (None, None) => return static_defaults, - }; - - let client = reqwest::Client::new(); - let mut request = client - .get("https://api.anthropic.com/v1/models") - .header("anthropic-version", "2023-06-01") - .timeout(std::time::Duration::from_secs(5)); - - if is_oauth { - request = request - .bearer_auth(&key_or_token) - .header("anthropic-beta", "oauth-2025-04-20"); - } else { - request = request.header("x-api-key", &key_or_token); - } - - let resp = match request.send().await { - Ok(r) if r.status().is_success() => r, - _ => return static_defaults, - }; - - 
#[derive(serde::Deserialize)] - struct ModelEntry { - id: String, - } - #[derive(serde::Deserialize)] - struct ModelsResponse { - data: Vec, - } - - match resp.json::().await { - Ok(body) => { - let mut models: Vec<(String, String)> = body - .data - .into_iter() - .filter(|m| !m.id.contains("embedding") && !m.id.contains("audio")) - .map(|m| { - let label = m.id.clone(); - (m.id, label) - }) - .collect(); - if models.is_empty() { - return static_defaults; - } - models.sort_by(|a, b| a.0.cmp(&b.0)); - models - } - Err(_) => static_defaults, - } -} - -/// Fetch models from the OpenAI API. -/// -/// Returns `(model_id, display_label)` pairs. Falls back to static defaults on error. -async fn fetch_openai_models(cached_key: Option<&str>) -> Vec<(String, String)> { - let static_defaults = vec![ - ( - "gpt-5.3-codex".into(), - "GPT-5.3 Codex (latest flagship)".into(), - ), - ("gpt-5.2-codex".into(), "GPT-5.2 Codex".into()), - ("gpt-5.2".into(), "GPT-5.2".into()), - ( - "gpt-5.1-codex-mini".into(), - "GPT-5.1 Codex Mini (fast)".into(), - ), - ("gpt-5".into(), "GPT-5".into()), - ("gpt-5-mini".into(), "GPT-5 Mini".into()), - ("gpt-4.1".into(), "GPT-4.1".into()), - ("gpt-4.1-mini".into(), "GPT-4.1 Mini".into()), - ("o4-mini".into(), "o4-mini (fast reasoning)".into()), - ("o3".into(), "o3 (reasoning)".into()), - ]; - - let api_key = cached_key - .map(String::from) - .or_else(|| std::env::var("OPENAI_API_KEY").ok()) - .filter(|k| !k.is_empty()); - - let api_key = match api_key { - Some(k) => k, - None => return static_defaults, - }; - - let client = reqwest::Client::new(); - let resp = match client - .get("https://api.openai.com/v1/models") - .bearer_auth(&api_key) - .timeout(std::time::Duration::from_secs(5)) - .send() - .await - { - Ok(r) if r.status().is_success() => r, - _ => return static_defaults, - }; - - #[derive(serde::Deserialize)] - struct ModelEntry { - id: String, - } - #[derive(serde::Deserialize)] - struct ModelsResponse { - data: Vec, - } - - match 
resp.json::().await { - Ok(body) => { - let mut models: Vec<(String, String)> = body - .data - .into_iter() - .filter(|m| is_openai_chat_model(&m.id)) - .map(|m| { - let label = m.id.clone(); - (m.id, label) - }) - .collect(); - if models.is_empty() { - return static_defaults; - } - sort_openai_models(&mut models); - models - } - Err(_) => static_defaults, - } -} - -fn is_openai_chat_model(model_id: &str) -> bool { - let id = model_id.to_ascii_lowercase(); - - let is_chat_family = id.starts_with("gpt-") - || id.starts_with("chatgpt-") - || id.starts_with("o1") - || id.starts_with("o3") - || id.starts_with("o4") - || id.starts_with("o5"); - - let is_non_chat_variant = id.contains("realtime") - || id.contains("audio") - || id.contains("transcribe") - || id.contains("tts") - || id.contains("embedding") - || id.contains("moderation") - || id.contains("image"); - - is_chat_family && !is_non_chat_variant -} - -fn openai_model_priority(model_id: &str) -> usize { - let id = model_id.to_ascii_lowercase(); - - const EXACT_PRIORITY: &[&str] = &[ - "gpt-5.3-codex", - "gpt-5.2-codex", - "gpt-5.2", - "gpt-5.1-codex-mini", - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - "o4-mini", - "o3", - "o1", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4o", - "gpt-4o-mini", - ]; - if let Some(pos) = EXACT_PRIORITY.iter().position(|m| id == *m) { - return pos; - } - - const PREFIX_PRIORITY: &[&str] = &[ - "gpt-5.", "gpt-5-", "o3-", "o4-", "o1-", "gpt-4.1-", "gpt-4o-", "gpt-3.5-", "chatgpt-", - ]; - if let Some(pos) = PREFIX_PRIORITY - .iter() - .position(|prefix| id.starts_with(prefix)) - { - return EXACT_PRIORITY.len() + pos; - } - - EXACT_PRIORITY.len() + PREFIX_PRIORITY.len() + 1 -} - -fn sort_openai_models(models: &mut [(String, String)]) { - models.sort_by(|a, b| { - openai_model_priority(&a.0) - .cmp(&openai_model_priority(&b.0)) - .then_with(|| a.0.cmp(&b.0)) - }); -} - -/// Fetch installed models from a local Ollama instance. -/// -/// Returns `(model_name, display_label)` pairs. 
Falls back to static defaults on error. -async fn fetch_ollama_models(base_url: &str) -> Vec<(String, String)> { - let static_defaults = vec![ - ("llama3".into(), "llama3".into()), - ("mistral".into(), "mistral".into()), - ("codellama".into(), "codellama".into()), - ]; - - let url = format!("{}/api/tags", base_url.trim_end_matches('/')); - let client = reqwest::Client::new(); - - let resp = match client - .get(&url) - .timeout(std::time::Duration::from_secs(5)) - .send() - .await - { - Ok(r) if r.status().is_success() => r, - Ok(_) => return static_defaults, - Err(_) => { - print_info("Could not connect to Ollama. Is it running?"); - return static_defaults; - } - }; - - #[derive(serde::Deserialize)] - struct ModelEntry { - name: String, - } - #[derive(serde::Deserialize)] - struct TagsResponse { - models: Vec, - } - - match resp.json::().await { - Ok(body) => { - let models: Vec<(String, String)> = body - .models - .into_iter() - .map(|m| { - let label = m.name.clone(); - (m.name, label) - }) - .collect(); - if models.is_empty() { - return static_defaults; - } - models - } - Err(_) => static_defaults, - } -} - -/// Fetch models from a generic OpenAI-compatible /v1/models endpoint. -/// -/// Used for registry providers like Groq, NVIDIA NIM, etc. 
-async fn fetch_openai_compatible_models( - base_url: &str, - cached_key: Option<&str>, -) -> Vec<(String, String)> { - if base_url.is_empty() { - return vec![]; - } - - let url = format!("{}/models", base_url.trim_end_matches('/')); - let client = reqwest::Client::new(); - let mut req = client.get(&url).timeout(std::time::Duration::from_secs(5)); - if let Some(key) = cached_key { - req = req.bearer_auth(key); - } - - let resp = match req.send().await { - Ok(r) if r.status().is_success() => r, - _ => return vec![], - }; - - #[derive(serde::Deserialize)] - struct Model { - id: String, - } - #[derive(serde::Deserialize)] - struct ModelsResponse { - data: Vec, - } - - match resp.json::().await { - Ok(body) => body - .data - .into_iter() - .map(|m| { - let label = m.id.clone(); - (m.id, label) - }) - .collect(), - Err(_) => vec![], - } -} - /// Discover WASM channels in a directory. /// /// Returns a list of (channel_name, capabilities_file) pairs. @@ -3380,58 +2707,6 @@ async fn discover_wasm_channels(dir: &std::path::Path) -> Vec<(String, ChannelCa /// Mask an API key for display: show first 6 + last 4 chars. /// /// Uses char-based indexing to avoid panicking on multi-byte UTF-8. -/// Build the `LlmConfig` used by `fetch_nearai_models` to list available models. -/// -/// Reads `NEARAI_API_KEY` from the environment so that users who authenticated -/// via Cloud API key (option 4) don't get re-prompted during model selection. -fn build_nearai_model_fetch_config() -> crate::config::LlmConfig { - // If the user authenticated via API key (option 4), the key is stored - // as an env var. Pass it through so `resolve_bearer_token()` doesn't - // re-trigger the interactive auth prompt. - let api_key = std::env::var("NEARAI_API_KEY") - .ok() - .filter(|k| !k.is_empty()) - .map(secrecy::SecretString::from); - - // Match the same base_url logic as LlmConfig::resolve(): use cloud-api - // when an API key is present, private.near.ai for session-token auth. 
- let default_base = if api_key.is_some() { - "https://cloud-api.near.ai" - } else { - "https://private.near.ai" - }; - let base_url = std::env::var("NEARAI_BASE_URL").unwrap_or_else(|_| default_base.to_string()); - let auth_base_url = - std::env::var("NEARAI_AUTH_URL").unwrap_or_else(|_| "https://private.near.ai".to_string()); - - crate::config::LlmConfig { - backend: "nearai".to_string(), - session: crate::llm::session::SessionConfig { - auth_base_url, - session_path: crate::config::llm::default_session_path(), - }, - nearai: crate::config::NearAiConfig { - model: "dummy".to_string(), - cheap_model: None, - base_url, - api_key, - fallback_model: None, - max_retries: 3, - circuit_breaker_threshold: None, - circuit_breaker_recovery_secs: 30, - response_cache_enabled: false, - response_cache_ttl_secs: 3600, - response_cache_max_entries: 1000, - failover_cooldown_secs: 300, - failover_cooldown_threshold: 3, - smart_routing_cascade: true, - }, - provider: None, - bedrock: None, - request_timeout_secs: 120, - } -} - fn mask_api_key(key: &str) -> String { let chars: Vec = key.chars().collect(); if chars.len() < 12 { @@ -3641,6 +2916,7 @@ mod tests { use super::*; use crate::config::helpers::ENV_MUTEX; + use crate::llm::models::{is_openai_chat_model, sort_openai_models}; #[test] fn test_wizard_creation() { @@ -3662,7 +2938,6 @@ mod tests { } #[test] - #[cfg(feature = "postgres")] fn test_mask_password_in_url() { assert_eq!( mask_password_in_url("postgres://user:secret@localhost/db"), diff --git a/src/tools/builtin/job.rs b/src/tools/builtin/job.rs index 8744f75b94..9346d14ab1 100644 --- a/src/tools/builtin/job.rs +++ b/src/tools/builtin/job.rs @@ -415,7 +415,19 @@ impl CreateJobTool { // loop stops consuming from inject_tx the send will fail and the // monitor terminates. No JoinHandle is retained. 
if let (Some(etx), Some(itx)) = (&self.event_tx, &self.inject_tx) { - crate::agent::job_monitor::spawn_job_monitor(job_id, etx.subscribe(), itx.clone()); + if let Some(route) = monitor_route_from_ctx(ctx) { + crate::agent::job_monitor::spawn_job_monitor( + job_id, + etx.subscribe(), + itx.clone(), + route, + ); + } else { + tracing::debug!( + job_id = %job_id, + "Skipping job monitor injection due to missing route metadata" + ); + } } let result = serde_json::json!({ @@ -680,6 +692,36 @@ fn resolve_project_dir( Ok((canonical_dir, browse_id)) } +fn monitor_route_from_ctx(ctx: &JobContext) -> Option { + // notify_channel is required — without it we don't know which channel to + // route the monitor output to, so return None to skip monitoring entirely. + let channel = ctx + .metadata + .get("notify_channel") + .and_then(|v| v.as_str())? + .to_string(); + // notify_user is optional — fall back to the job's own user_id, which is + // always present. The channel is the routing decision; the user is just + // for attribution and can default safely. + let user_id = ctx + .metadata + .get("notify_user") + .and_then(|v| v.as_str()) + .unwrap_or(&ctx.user_id) + .to_string(); + let thread_id = ctx + .metadata + .get("notify_thread_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + Some(crate::agent::job_monitor::JobMonitorRoute { + channel, + user_id, + thread_id, + }) +} + #[async_trait] impl Tool for CreateJobTool { fn name(&self) -> &str { diff --git a/src/worker/job.rs b/src/worker/job.rs index 1247a5522b..0f0e969ee7 100644 --- a/src/worker/job.rs +++ b/src/worker/job.rs @@ -1170,11 +1170,16 @@ impl<'a> LoopDelegate for JobDelegate<'a> { // Reset counter after a successful LLM call self.consecutive_rate_limits .store(0, std::sync::atomic::Ordering::Relaxed); + // Preserve the LLM's reasoning text so it appears in the + // assistant_with_tool_calls message pushed by execute_tool_calls. 
+ let reasoning_text = s + .iter() + .find_map(|sel| (!sel.reasoning.is_empty()).then_some(sel.reasoning.clone())); let tool_calls: Vec = selections_to_tool_calls(&s); return Ok(crate::llm::RespondOutput { result: RespondResult::ToolCalls { tool_calls, - content: None, + content: reasoning_text, }, usage: crate::llm::TokenUsage::default(), }); @@ -1586,7 +1591,7 @@ mod tests { } #[tokio::test] - async fn test_mark_completed_twice_returns_error() { + async fn test_mark_completed_twice_is_idempotent() { let worker = make_worker(vec![]).await; worker @@ -1607,11 +1612,22 @@ mod tests { .unwrap(); assert_eq!(ctx.state, JobState::Completed); + // Second mark_completed should succeed (idempotent) rather than + // erroring, matching the fix for the execution_loop / worker wrapper + // race condition. let result = worker.mark_completed().await; assert!( - result.is_err(), - "Completed → Completed transition should be rejected by state machine" + result.is_ok(), + "Completed -> Completed transition should be idempotent" ); + + // State should still be Completed + let ctx = worker + .context_manager() + .get_context(worker.job_id) + .await + .unwrap(); + assert_eq!(ctx.state, JobState::Completed); } /// Build a Worker with the given approval context. @@ -1849,4 +1865,128 @@ mod tests { "Iteration cap should transition to Failed, not Stuck" ); } + + /// Regression test: selections_to_tool_calls must preserve tool_call_id + /// so that tool_result messages match the assistant_with_tool_calls message + /// and are not treated as orphaned by sanitize_tool_messages. 
+ #[test] + fn test_selections_to_tool_calls_preserves_ids() { + let selections = vec![ + ToolSelection { + tool_name: "search".into(), + parameters: serde_json::json!({"q": "test"}), + reasoning: "Need to search".into(), + alternatives: vec![], + tool_call_id: "call_abc".into(), + }, + ToolSelection { + tool_name: "fetch".into(), + parameters: serde_json::json!({"url": "https://example.com"}), + reasoning: "Need to fetch".into(), + alternatives: vec![], + tool_call_id: "call_def".into(), + }, + ]; + + let tool_calls = selections_to_tool_calls(&selections); + + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].id, "call_abc"); + assert_eq!(tool_calls[0].name, "search"); + assert_eq!(tool_calls[1].id, "call_def"); + assert_eq!(tool_calls[1].name, "fetch"); + } + + /// Regression test: when select_tools returns selections with reasoning, + /// the reasoning text should be preserved as content in the RespondResult + /// so it appears in the assistant_with_tool_calls message. Without this, + /// the LLM's reasoning context is lost and subsequent turns lack context. 
+ #[test] + fn test_reasoning_text_extraction_from_selections() { + // Simulate what call_llm does: extract first non-empty reasoning + let selections = [ + ToolSelection { + tool_name: "search".into(), + parameters: serde_json::json!({}), + reasoning: "I need to search for relevant information".into(), + alternatives: vec![], + tool_call_id: "call_1".into(), + }, + ToolSelection { + tool_name: "fetch".into(), + parameters: serde_json::json!({}), + reasoning: "I need to search for relevant information".into(), + alternatives: vec![], + tool_call_id: "call_2".into(), + }, + ]; + + let reasoning_text = selections + .iter() + .find_map(|sel| (!sel.reasoning.is_empty()).then_some(sel.reasoning.clone())); + + assert_eq!( + reasoning_text.as_deref(), + Some("I need to search for relevant information"), + "Reasoning text should be extracted from first non-empty selection" + ); + + // Empty reasoning should result in None + let empty_selections = [ToolSelection { + tool_name: "echo".into(), + parameters: serde_json::json!({}), + reasoning: String::new(), + alternatives: vec![], + tool_call_id: "call_3".into(), + }]; + + let empty_reasoning = empty_selections + .iter() + .find_map(|sel| (!sel.reasoning.is_empty()).then_some(sel.reasoning.clone())); + + assert!( + empty_reasoning.is_none(), + "Empty reasoning should not be included as content" + ); + } + + /// When the first selection has empty reasoning but a subsequent one has + /// non-empty reasoning, find_map should skip the empty one and return the + /// first non-empty reasoning. 
+ #[test] + fn test_reasoning_text_skips_empty_first_selection() { + let selections = [ + ToolSelection { + tool_name: "echo".into(), + parameters: serde_json::json!({}), + reasoning: String::new(), + alternatives: vec![], + tool_call_id: "call_1".into(), + }, + ToolSelection { + tool_name: "search".into(), + parameters: serde_json::json!({}), + reasoning: "Found the answer in the second selection".into(), + alternatives: vec![], + tool_call_id: "call_2".into(), + }, + ToolSelection { + tool_name: "fetch".into(), + parameters: serde_json::json!({}), + reasoning: "Third selection reasoning".into(), + alternatives: vec![], + tool_call_id: "call_3".into(), + }, + ]; + + let reasoning_text = selections + .iter() + .find_map(|sel| (!sel.reasoning.is_empty()).then_some(sel.reasoning.clone())); + + assert_eq!( + reasoning_text.as_deref(), + Some("Found the answer in the second selection"), + "Should skip empty first reasoning and return the first non-empty one" + ); + } } diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index dced10ea8e..b19c77af1a 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -39,6 +39,9 @@ # Temp directory for the libSQL database file (cleaned up automatically) _DB_TMPDIR = tempfile.TemporaryDirectory(prefix="ironclaw-e2e-") +# Temp HOME so pairing/allowFrom state never touches the developer's real ~/.ironclaw +_HOME_TMPDIR = tempfile.TemporaryDirectory(prefix="ironclaw-e2e-home-") + # Temp directories for WASM extensions. These start empty and are populated by # the install pipeline during tests; fixtures do not pre-populate dev build # artifacts into them. 
@@ -46,6 +49,42 @@ _WASM_CHANNELS_TMPDIR = tempfile.TemporaryDirectory(prefix="ironclaw-e2e-wasm-channels-") +def _latest_mtime(path: Path) -> float: + """Return the newest mtime under a file or directory.""" + if not path.exists(): + return 0.0 + if path.is_file(): + return path.stat().st_mtime + + latest = path.stat().st_mtime + for root, dirnames, filenames in os.walk(path): + dirnames[:] = [dirname for dirname in dirnames if dirname != "target"] + for name in filenames: + child = Path(root) / name + try: + latest = max(latest, child.stat().st_mtime) + except FileNotFoundError: + continue + return latest + + +def _binary_needs_rebuild(binary: Path) -> bool: + """Rebuild when the binary is missing or older than embedded sources.""" + if not binary.exists(): + return True + + binary_mtime = binary.stat().st_mtime + inputs = [ + ROOT / "Cargo.toml", + ROOT / "Cargo.lock", + ROOT / "build.rs", + ROOT / "providers.json", + ROOT / "src", + ROOT / "channels-src", + ] + return any(_latest_mtime(path) > binary_mtime for path in inputs) + + def _find_free_port() -> int: """Bind to port 0 and return the OS-assigned port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -57,7 +96,7 @@ def _find_free_port() -> int: def ironclaw_binary(): """Ensure ironclaw binary is built. Returns the binary path.""" binary = ROOT / "target" / "debug" / "ironclaw" - if not binary.exists(): + if _binary_needs_rebuild(binary): print("Building ironclaw (this may take a while)...") subprocess.run( ["cargo", "build", "--no-default-features", "--features", "libsql"], @@ -141,10 +180,12 @@ def _wasm_build_symlinks(): async def ironclaw_server(ironclaw_binary, mock_llm_server, wasm_tools_dir): """Start the ironclaw gateway. 
Yields the base URL.""" gateway_port = _find_free_port() + home_dir = _HOME_TMPDIR.name env = { # Minimal env: PATH for process spawning, HOME for Rust/cargo defaults "PATH": os.environ.get("PATH", "/usr/bin:/bin"), - "HOME": os.environ.get("HOME", "/tmp"), + "HOME": home_dir, + "IRONCLAW_BASE_DIR": os.path.join(home_dir, ".ironclaw"), "RUST_LOG": "ironclaw=info", "RUST_BACKTRACE": "1", "GATEWAY_ENABLED": "true", diff --git a/tests/e2e/scenarios/test_telegram_hot_activation.py b/tests/e2e/scenarios/test_telegram_hot_activation.py new file mode 100644 index 0000000000..833803d650 --- /dev/null +++ b/tests/e2e/scenarios/test_telegram_hot_activation.py @@ -0,0 +1,236 @@ +"""Telegram hot-activation UI coverage.""" + +import asyncio +import json + +from helpers import SEL + +_CONFIGURE_SECRET_INPUT = "input[type='password']" +_CONFIGURE_SAVE_BUTTON = ".configure-actions button.btn-ext.activate" + + +_TELEGRAM_INSTALLED = { + "name": "telegram", + "display_name": "Telegram", + "kind": "wasm_channel", + "description": "Telegram Bot API channel", + "url": None, + "active": False, + "authenticated": False, + "has_auth": False, + "needs_setup": True, + "tools": [], + "activation_status": "installed", + "activation_error": None, +} + +_TELEGRAM_ACTIVE = { + **_TELEGRAM_INSTALLED, + "active": True, + "authenticated": True, + "needs_setup": False, + "activation_status": "active", +} + + +async def go_to_extensions(page): + await page.locator(SEL["tab_button"].format(tab="extensions")).click() + await page.locator(SEL["tab_panel"].format(tab="extensions")).wait_for( + state="visible", timeout=5000 + ) + await page.locator( + f"{SEL['extensions_list']} .empty-state, {SEL['ext_card_installed']}" + ).first.wait_for(state="visible", timeout=8000) + + +async def mock_extension_lists(page, ext_handler): + async def handle_ext_list(route): + path = route.request.url.split("?")[0] + if path.endswith("/api/extensions"): + await ext_handler(route) + else: + await route.continue_() + + 
async def handle_tools(route): + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"tools": []}), + ) + + async def handle_registry(route): + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"entries": []}), + ) + + # Register the broad route first so the specific endpoints below win. + await page.route("**/api/extensions*", handle_ext_list) + await page.route("**/api/extensions/tools", handle_tools) + await page.route("**/api/extensions/registry", handle_registry) + + +async def wait_for_toast(page, text: str, *, timeout: int = 5000): + await page.locator(SEL["toast"], has_text=text).wait_for( + state="visible", timeout=timeout + ) + + +async def test_telegram_setup_modal_shows_bot_token_field(page): + async def handle_ext_list(route): + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"extensions": [_TELEGRAM_INSTALLED]}), + ) + + async def handle_setup(route): + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps( + { + "secrets": [ + { + "name": "telegram_bot_token", + "prompt": "Enter your Telegram Bot API token (from @BotFather)", + "provided": False, + "optional": False, + "auto_generate": False, + } + ] + } + ), + ) + + await mock_extension_lists(page, handle_ext_list) + await page.route("**/api/extensions/telegram/setup", handle_setup) + await go_to_extensions(page) + + card = page.locator(SEL["ext_card_installed"]).first + await card.locator(SEL["ext_configure_btn"], has_text="Setup").click() + + modal = page.locator(SEL["configure_modal"]) + await modal.wait_for(state="visible", timeout=5000) + assert "Telegram Bot API token" in await modal.text_content() + assert "IronClaw will show a one-time code" in ( + await modal.text_content() + ) + input_el = modal.locator(_CONFIGURE_SECRET_INPUT) + assert await input_el.count() == 1 + + +async def 
test_telegram_hot_activation_transitions_installed_to_active(page): + phase = {"value": "installed"} + captured_setup_payloads = [] + post_count = {"value": 0} + + async def handle_ext_list(route): + extensions = { + "installed": [_TELEGRAM_INSTALLED], + "active": [_TELEGRAM_ACTIVE], + }[phase["value"]] + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"extensions": extensions}), + ) + + async def handle_setup(route): + if route.request.method == "GET": + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps( + { + "secrets": [ + { + "name": "telegram_bot_token", + "prompt": "Enter your Telegram Bot API token (from @BotFather)", + "provided": False, + "optional": False, + "auto_generate": False, + } + ] + } + ), + ) + return + + payload = json.loads(route.request.post_data or "{}") + captured_setup_payloads.append(payload) + post_count["value"] += 1 + await asyncio.sleep(0.05) + if post_count["value"] == 1: + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps( + { + "success": True, + "activated": False, + "message": "Configuration saved for 'telegram'. Send `/start iclaw-7qk2m9` to @test_hot_bot, then click Verify owner.", + "verification": { + "code": "iclaw-7qk2m9", + "instructions": "Send `/start iclaw-7qk2m9` to @test_hot_bot, then click Verify owner.", + "deep_link": "https://t.me/test_hot_bot?start=iclaw-7qk2m9", + }, + } + ), + ) + else: + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps( + { + "success": True, + "activated": True, + "message": "Configuration saved, Telegram owner verified, and 'telegram' activated. 
Hot-activated WASM channel", + } + ), + ) + + await mock_extension_lists(page, handle_ext_list) + await page.route("**/api/extensions/telegram/setup", handle_setup) + await go_to_extensions(page) + + card = page.locator(SEL["ext_card_installed"]).first + await card.locator(SEL["ext_configure_btn"], has_text="Setup").click() + + modal = page.locator(SEL["configure_modal"]) + await modal.wait_for(state="visible", timeout=5000) + await modal.locator(_CONFIGURE_SECRET_INPUT).fill("123456789:ABCdefGhI") + await modal.locator(_CONFIGURE_SAVE_BUTTON).click() + await modal.locator(_CONFIGURE_SAVE_BUTTON, has_text="Verify owner").wait_for( + state="visible", timeout=5000 + ) + assert "Verify owner" in ( + await modal.locator(_CONFIGURE_SAVE_BUTTON).text_content() + ) + assert "iclaw-7qk2m9" in (await modal.text_content()) + assert await modal.locator(".configure-verification-link").count() == 1 + + await modal.locator(_CONFIGURE_SAVE_BUTTON).click() + await page.locator(SEL["configure_overlay"]).wait_for(state="hidden", timeout=5000) + + phase["value"] = "active" + await page.evaluate( + """ + handleAuthCompleted({ + extension_name: 'telegram', + success: true, + message: "Configuration saved, Telegram owner verified, and 'telegram' activated. Hot-activated WASM channel", + }); + """ + ) + + await wait_for_toast(page, "Telegram owner verified") + await card.locator(SEL["ext_active_label"]).wait_for(state="visible", timeout=5000) + assert await card.locator(SEL["ext_pairing_label"]).count() == 0 + + assert captured_setup_payloads == [ + {"secrets": {"telegram_bot_token": "123456789:ABCdefGhI"}}, + {"secrets": {}}, + ] diff --git a/tests/e2e/scenarios/test_telegram_token_validation.py b/tests/e2e/scenarios/test_telegram_token_validation.py new file mode 100644 index 0000000000..69d04e51f4 --- /dev/null +++ b/tests/e2e/scenarios/test_telegram_token_validation.py @@ -0,0 +1,172 @@ +"""Scenario: Telegram bot token validation - configure modal UI test. 
+ +Tests the Telegram extension configure modal renders and accepts tokens with colons. + +Note: The core URL-building logic (colon preservation, no %3A encoding) is verified +by unit tests in src/extensions/manager.rs. This E2E test verifies the configure modal +UI can accept Telegram tokens with colons and renders correctly. +""" + +import json + +from helpers import SEL + + +# ─── Fixture data ───────────────────────────────────────────────────────────── + +_TELEGRAM_EXTENSION = { + "name": "telegram", + "display_name": "Telegram", + "kind": "wasm_channel", + "description": "Telegram bot channel", + "url": None, + "active": False, + "authenticated": False, + "has_auth": True, + "needs_setup": True, + "tools": [], + "activation_status": "installed", + "activation_error": None, +} + +_TELEGRAM_SECRETS = [ + { + "name": "telegram_bot_token", + "prompt": "Telegram Bot Token", + "provided": False, + "optional": False, + "auto_generate": False, + } +] + + +# ─── Tests ──────────────────────────────────────────────────────────────────── + +async def test_telegram_configure_modal_renders(page): + """ + Telegram extension configure modal renders with correct fields. + + Verifies that the configure modal appears with the Telegram bot token field + and all expected UI elements are present. 
+ """ + ext_body = json.dumps({"extensions": [_TELEGRAM_EXTENSION]}) + + async def handle_ext_list(route): + if route.request.url.endswith("/api/extensions"): + await route.fulfill( + status=200, content_type="application/json", body=ext_body + ) + else: + await route.continue_() + + await page.route("**/api/extensions*", handle_ext_list) + + async def handle_setup(route): + if route.request.method == "GET": + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"secrets": _TELEGRAM_SECRETS}), + ) + else: + await route.continue_() + + await page.route("**/api/extensions/telegram/setup", handle_setup) + await page.evaluate("showConfigureModal('telegram')") + modal = page.locator(SEL["configure_modal"]) + await modal.wait_for(state="visible", timeout=5000) + + # Modal should contain the extension name and token prompt + modal_text = await modal.text_content() + assert "telegram" in modal_text.lower() + assert "bot token" in modal_text.lower() + + # Input field should be present + input_field = page.locator(SEL["configure_input"]) + assert await input_field.is_visible() + + +async def test_telegram_token_input_accepts_colon_format(page): + """ + Telegram bot token input accepts tokens with colon separator. + + Verifies that a token in the format `numeric_id:alphanumeric_string` + can be entered without browser-side validation errors. 
+ """ + ext_body = json.dumps({"extensions": [_TELEGRAM_EXTENSION]}) + + async def handle_ext_list(route): + if route.request.url.endswith("/api/extensions"): + await route.fulfill( + status=200, content_type="application/json", body=ext_body + ) + else: + await route.continue_() + + await page.route("**/api/extensions*", handle_ext_list) + + async def handle_setup(route): + if route.request.method == "GET": + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"secrets": _TELEGRAM_SECRETS}), + ) + + await page.route("**/api/extensions/telegram/setup", handle_setup) + await page.evaluate("showConfigureModal('telegram')") + await page.locator(SEL["configure_modal"]).wait_for(state="visible", timeout=5000) + + # Enter a valid Telegram bot token with colon + token_value = "123456789:AABBccDDeeFFgg_Test-Token" + input_field = page.locator(SEL["configure_input"]) + await input_field.fill(token_value) + + # Verify the value was entered and colon is preserved + entered_value = await input_field.input_value() + assert entered_value == token_value + assert ":" in entered_value, "Colon should be preserved in token" + assert "%3A" not in entered_value, "Colon should not be URL-encoded in input" + + +async def test_telegram_token_with_underscores_and_hyphens(page): + """ + Telegram tokens with hyphens and underscores are accepted. + + Verifies that valid Telegram token characters (hyphens, underscores) are + properly accepted by the input field. 
+ """ + ext_body = json.dumps({"extensions": [_TELEGRAM_EXTENSION]}) + + async def handle_ext_list(route): + if route.request.url.endswith("/api/extensions"): + await route.fulfill( + status=200, content_type="application/json", body=ext_body + ) + else: + await route.continue_() + + await page.route("**/api/extensions*", handle_ext_list) + + async def handle_setup(route): + if route.request.method == "GET": + await route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"secrets": _TELEGRAM_SECRETS}), + ) + + await page.route("**/api/extensions/telegram/setup", handle_setup) + await page.evaluate("showConfigureModal('telegram')") + await page.locator(SEL["configure_modal"]).wait_for(state="visible", timeout=5000) + + # Token with hyphens and underscores + token_value = "987654321:ABCD-EFgh_ijkl-MNOP_qrst" + input_field = page.locator(SEL["configure_input"]) + await input_field.fill(token_value) + + # Verify the value was entered correctly with all characters preserved + entered_value = await input_field.input_value() + assert entered_value == token_value + assert "-" in entered_value + assert "_" in entered_value diff --git a/tests/e2e_routine_heartbeat.rs b/tests/e2e_routine_heartbeat.rs index f5a28c25b6..6d6deb8bec 100644 --- a/tests/e2e_routine_heartbeat.rs +++ b/tests/e2e_routine_heartbeat.rs @@ -218,18 +218,7 @@ mod tests { engine.refresh_event_cache().await; // Positive match: message containing "deploy to production". 
- let matching_msg = IncomingMessage { - id: Uuid::new_v4(), - channel: "test".to_string(), - user_id: "default".to_string(), - user_name: None, - content: "deploy to production now".to_string(), - thread_id: None, - received_at: Utc::now(), - metadata: serde_json::json!({}), - timezone: None, - attachments: Vec::new(), - }; + let matching_msg = IncomingMessage::new("test", "default", "deploy to production now"); let fired = engine.check_event_triggers(&matching_msg).await; assert!( fired >= 1, @@ -240,18 +229,8 @@ mod tests { tokio::time::sleep(Duration::from_millis(500)).await; // Negative match: message that doesn't match. - let non_matching_msg = IncomingMessage { - id: Uuid::new_v4(), - channel: "test".to_string(), - user_id: "default".to_string(), - user_name: None, - content: "check the staging environment".to_string(), - thread_id: None, - received_at: Utc::now(), - metadata: serde_json::json!({}), - timezone: None, - attachments: Vec::new(), - }; + let non_matching_msg = + IncomingMessage::new("test", "default", "check the staging environment"); let fired_neg = engine.check_event_triggers(&non_matching_msg).await; assert_eq!(fired_neg, 0, "Expected 0 routines fired on non-match"); } @@ -455,18 +434,7 @@ mod tests { engine.refresh_event_cache().await; // First fire should work. 
- let msg = IncomingMessage { - id: Uuid::new_v4(), - channel: "test".to_string(), - user_id: "default".to_string(), - user_name: None, - content: "test-cooldown trigger".to_string(), - thread_id: None, - received_at: Utc::now(), - metadata: serde_json::json!({}), - timezone: None, - attachments: Vec::new(), - }; + let msg = IncomingMessage::new("test", "default", "test-cooldown trigger"); let fired1 = engine.check_event_triggers(&msg).await; assert!(fired1 >= 1, "First fire should work"); diff --git a/tests/support/gateway_workflow_harness.rs b/tests/support/gateway_workflow_harness.rs index dd9e86430d..c539dad504 100644 --- a/tests/support/gateway_workflow_harness.rs +++ b/tests/support/gateway_workflow_harness.rs @@ -143,6 +143,9 @@ impl GatewayWorkflowHarness { model: model.to_string(), extra_headers: Vec::new(), oauth_token: None, + is_codex_chatgpt: false, + refresh_token: None, + auth_path: None, cache_retention: Default::default(), unsupported_params: Vec::new(), }); diff --git a/tests/telegram_auth_integration.rs b/tests/telegram_auth_integration.rs index 01d246a64b..8b27d8a8c8 100644 --- a/tests/telegram_auth_integration.rs +++ b/tests/telegram_auth_integration.rs @@ -40,8 +40,31 @@ macro_rules! 
require_telegram_wasm! {

/// Path to the built Telegram WASM module
fn telegram_wasm_path() -> std::path::PathBuf {
-    std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
-        .join("channels-src/telegram/target/wasm32-wasip2/release/telegram_channel.wasm")
+    // Prefer the module built inside this checkout.
+    let local = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+        .join("channels-src/telegram/target/wasm32-wasip2/release/telegram_channel.wasm");
+    if local.exists() {
+        return local;
+    }
+
+    // Fall back to sibling git worktrees: another checkout of this repo may
+    // already have built the WASM module, saving a redundant build here.
+    // Best-effort only — any git failure just falls through to `local`.
+    if let Ok(output) = std::process::Command::new("git")
+        .args(["worktree", "list", "--porcelain"])
+        .output()
+        && output.status.success()
+    {
+        let stdout = String::from_utf8_lossy(&output.stdout);
+        for line in stdout.lines() {
+            // Porcelain output lists each worktree as a "worktree <path>" line.
+            if let Some(path) = line.strip_prefix("worktree ") {
+                let candidate = std::path::PathBuf::from(path).join(
+                    "channels-src/telegram/target/wasm32-wasip2/release/telegram_channel.wasm",
+                );
+                if candidate.exists() {
+                    return candidate;
+                }
+            }
+        }
+    }
+
+    // Not found anywhere: return the canonical local path so the caller's
+    // "missing WASM" diagnostics point at the expected build location.
+    local
}

/// Create a test runtime for WASM channel operations.