diff --git a/.githooks/pre-push b/.githooks/pre-push
index cd6b5cd413..e9c7d8dada 100755
--- a/.githooks/pre-push
+++ b/.githooks/pre-push
@@ -1,23 +1,18 @@
 #!/usr/bin/env bash
 set -euo pipefail
+# Pre-push hook: runs quality gate before pushing
+# Skip with: git push --no-verify
 
-# Pre-push hook: run clippy and tests before pushing.
-# Install: git config core.hooksPath .githooks
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCRIPT_DIR="$REPO_ROOT/scripts/ci"
 
-echo "pre-push: running clippy..."
-if ! cargo clippy --all --benches --tests --examples --all-features -- -D warnings; then
-    echo ""
-    echo "Push blocked: clippy warnings found."
-    echo "To bypass: git push --no-verify"
-    exit 1
-fi
+# Default: baseline quality gate
+"$SCRIPT_DIR/quality_gate.sh"
 
-echo "pre-push: running tests..."
-if ! cargo test; then
-    echo ""
-    echo "Push blocked: tests failed."
-    echo "To bypass: git push --no-verify"
-    exit 1
+# Optional strict delta lint (env-gated); remote name is forwarded when given, empty (not unbound under set -u) otherwise
+if [ "${IRONCLAW_STRICT_DELTA_LINT:-0}" = "1" ]; then
+    "$SCRIPT_DIR/delta_lint.sh" "${1:-}"
+elif [ "${IRONCLAW_STRICT_LINT:-0}" = "1" ]; then
+    echo "==> clippy (strict: all warnings)"
+    cargo clippy --locked --all-targets -- -D warnings
 fi
-
-echo "pre-push: all checks passed."
diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md
index 6e308a10a7..db4ab92a4c 100644
--- a/FEATURE_PARITY.md
+++ b/FEATURE_PARITY.md
@@ -176,7 +176,7 @@ This document tracks feature parity between IronClaw (Rust implementation) and O
 | `browser` | ✅ | ❌ | P3 | Browser automation |
 | `sandbox` | ✅ | ✅ | - | WASM sandbox |
 | `doctor` | ✅ | 🚧 | P2 | 16 subsystem checks |
-| `logs` | ✅ | ❌ | P3 | Query logs |
+| `logs` | ✅ | 🚧 | P3 | `logs` (gateway.log tail), `--follow` (SSE live stream), `--level` (get/set). No DB-persisted log history.
|
 | `update` | ✅ | ❌ | P3 | Self-update |
 | `completion` | ✅ | ✅ | - | Shell completion |
 | `/subagents spawn` | ✅ | ❌ | P3 | Spawn subagents from chat |
diff --git a/scripts/ci/delta_lint.sh b/scripts/ci/delta_lint.sh
new file mode 100755
index 0000000000..c64b91a7ec
--- /dev/null
+++ b/scripts/ci/delta_lint.sh
@@ -0,0 +1,216 @@
+#!/usr/bin/env bash
+set -euo pipefail
+# Delta lint: only fail on clippy warnings/errors that touch changed lines.
+# Compares the current branch against the merge base with the upstream default branch.
+
+CLIPPY_OUT=""
+DIFF_OUT=""
+CLIPPY_STDERR=""
+
+cleanup() { # must return 0: a false final test here would fail the EXIT trap under set -e and override an early `exit 0`
+    [ -z "$CLIPPY_OUT" ] || rm -f "$CLIPPY_OUT"
+    [ -z "$DIFF_OUT" ] || rm -f "$DIFF_OUT"
+    [ -z "$CLIPPY_STDERR" ] || rm -f "$CLIPPY_STDERR"
+}
+trap cleanup EXIT
+
+# Verify python3 is available (needed for diagnostic filtering)
+if ! command -v python3 &>/dev/null; then
+    echo "ERROR: python3 is required for delta lint but not found"
+    exit 1
+fi
+
+# Accept optional remote name argument; default to dynamic detection
+REMOTE="${1:-}"
+
+# Determine the upstream base ref dynamically
+BASE_REF=""
+if [ -n "$REMOTE" ]; then
+    # Use the provided remote name
+    if [ -z "$BASE_REF" ]; then
+        BASE_REF=$(git symbolic-ref "refs/remotes/$REMOTE/HEAD" 2>/dev/null | sed 's|refs/remotes/||' || true)
+    fi
+    if [ -z "$BASE_REF" ] && git rev-parse --verify "$REMOTE/main" &>/dev/null; then
+        BASE_REF="$REMOTE/main"
+    fi
+    if [ -z "$BASE_REF" ] && git rev-parse --verify "$REMOTE/master" &>/dev/null; then
+        BASE_REF="$REMOTE/master"
+    fi
+else
+    # Try the remote HEAD symbolic ref (works for any default branch name)
+    if [ -z "$BASE_REF" ]; then
+        BASE_REF=$(git symbolic-ref refs/remotes/origin/HEAD 2>/dev/null | sed 's|refs/remotes/||' || true)
+    fi
+    # Fall back to common default branch names
+    if [ -z "$BASE_REF" ] && git rev-parse --verify origin/main &>/dev/null; then
+        BASE_REF="origin/main"
+    fi
+    if [ -z "$BASE_REF" ] && git rev-parse --verify origin/master &>/dev/null; then
+
BASE_REF="origin/master" + fi +fi +if [ -z "$BASE_REF" ]; then + echo "WARNING: could not determine upstream base branch, skipping delta lint" + exit 0 +fi + +# Compute merge base +BASE=$(git merge-base "$BASE_REF" HEAD 2>/dev/null) || { + echo "WARNING: git merge-base failed for $BASE_REF, skipping delta lint" + exit 0 +} + +# Find changed .rs files +CHANGED_RS=$(git diff --name-only "$BASE" -- '*.rs' || true) +if [ -z "$CHANGED_RS" ]; then + echo "==> delta lint: no .rs files changed, skipping" + exit 0 +fi + +echo "==> delta lint: checking changed lines since $(echo "$BASE" | head -c 10)..." + +# Extract unified-0 diff for changed line ranges +DIFF_OUT=$(mktemp "${TMPDIR:-/tmp}/ironclaw-diff.XXXXXX") +git diff --unified=0 "$BASE" -- '*.rs' > "$DIFF_OUT" + +# Run clippy with JSON output (stderr shows compilation progress/errors) +CLIPPY_OUT=$(mktemp "${TMPDIR:-/tmp}/ironclaw-clippy.XXXXXX") +CLIPPY_STDERR=$(mktemp "${TMPDIR:-/tmp}/ironclaw-clippy-err.XXXXXX") +cargo clippy --locked --all-targets --message-format=json > "$CLIPPY_OUT" 2>"$CLIPPY_STDERR" || true + +# Show compilation errors if clippy produced no JSON output +if [ ! -s "$CLIPPY_OUT" ] && [ -s "$CLIPPY_STDERR" ]; then + echo "ERROR: clippy failed to produce output. 
Compilation errors:" + cat "$CLIPPY_STDERR" + exit 1 +fi + +# Get repo root for path normalization in Python +REPO_ROOT="$(git rev-parse --show-toplevel)" + +# Filter clippy diagnostics against changed line ranges +python3 - "$DIFF_OUT" "$CLIPPY_OUT" "$REPO_ROOT" <<'PYEOF' +import json +import re +import sys +import os + +def parse_diff(diff_path): + """Parse unified-0 diff to extract {file: [[start, end], ...]} changed ranges.""" + changed = {} + current_file = None + with open(diff_path) as f: + for line in f: + # Match +++ b/path/to/file.rs or +++ /dev/null (deletion) + if line.startswith('+++ /dev/null'): + current_file = None + continue + m = re.match(r'^\+\+\+ b/(.+)$', line) + if m: + current_file = m.group(1) + if current_file not in changed: + changed[current_file] = [] + continue + # Match @@ hunk headers: @@ -old,count +new,count @@ + m = re.match(r'^@@ .+ \+(\d+)(?:,(\d+))? @@', line) + if m and current_file: + start = int(m.group(1)) + count = int(m.group(2)) if m.group(2) is not None else 1 + if count == 0: + continue + end = start + count - 1 + changed[current_file].append([start, end]) + return changed + +def normalize_path(path, repo_root): + """Normalize absolute path to relative (from repo root).""" + if os.path.isabs(path): + if path.startswith(repo_root): + return os.path.relpath(path, repo_root) + return path + +def in_changed_range(file_path, line_start, line_end, changed_ranges, repo_root): + """Check if file:[line_start, line_end] overlaps any changed range.""" + rel = normalize_path(file_path, repo_root) + ranges = changed_ranges.get(rel) + if not ranges: + return False + return any(start <= line_end and line_start <= end for start, end in ranges) + +def main(): + diff_path = sys.argv[1] + clippy_path = sys.argv[2] + repo_root = sys.argv[3] + + changed_ranges = parse_diff(diff_path) + + blocking = [] + baseline = [] + + with open(clippy_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + msg = 
json.loads(line) + except json.JSONDecodeError: + continue + + if msg.get("reason") != "compiler-message": + continue + + cm = msg.get("message", {}) + level = cm.get("level", "") + if level not in ("warning", "error"): + continue + + rendered = cm.get("rendered", "").strip() + + # Errors are always blocking regardless of location + if level == "error": + blocking.append(rendered) + continue + + # For warnings, only block if they overlap changed lines + spans = cm.get("spans", []) + primary = None + for s in spans: + if s.get("is_primary"): + primary = s + break + if not primary: + if spans: + primary = spans[0] + else: + baseline.append(rendered) + continue + + file_name = primary.get("file_name", "") + line_start = primary.get("line_start", 0) + line_end = primary.get("line_end", line_start) + + if in_changed_range(file_name, line_start, line_end, changed_ranges, repo_root): + blocking.append(rendered) + else: + baseline.append(rendered) + + if baseline: + print(f"\n--- Baseline warnings (not in changed lines, informational) [{len(baseline)}] ---") + for w in baseline[:10]: + print(w) + if len(baseline) > 10: + print(f" ... 
and {len(baseline) - 10} more") + + if blocking: + print(f"\n*** BLOCKING: {len(blocking)} issue(s) in changed lines ***") + for w in blocking: + print(w) + sys.exit(1) + else: + print("\n==> delta lint: passed (no issues in changed lines)") + sys.exit(0) + +if __name__ == "__main__": + main() +PYEOF diff --git a/scripts/ci/quality_gate.sh b/scripts/ci/quality_gate.sh new file mode 100755 index 0000000000..83a62e0290 --- /dev/null +++ b/scripts/ci/quality_gate.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +echo "==> fmt check" +cargo fmt --all -- --check + +echo "==> clippy (correctness)" +cargo clippy --locked --all-targets -- -D clippy::correctness + +if [ "${IRONCLAW_PREPUSH_TEST:-1}" = "1" ]; then + echo "==> tests (skip with IRONCLAW_PREPUSH_TEST=0)" + cargo test --locked --lib +fi diff --git a/scripts/dev-setup.sh b/scripts/dev-setup.sh index faa5aa2c68..4d272f49e0 100755 --- a/scripts/dev-setup.sh +++ b/scripts/dev-setup.sh @@ -56,6 +56,9 @@ if [ -n "$HOOKS_DIR" ]; then echo " commit-msg hook installed (regression test enforcement)" ln -sf "$SCRIPTS_ABS/pre-commit-safety.sh" "$HOOKS_DIR/pre-commit" echo " pre-commit hook installed (UTF-8, case-sensitivity, /tmp, redaction checks)" + REPO_ROOT="$(git rev-parse --show-toplevel)" + ln -sf "$REPO_ROOT/.githooks/pre-push" "$HOOKS_DIR/pre-push" + echo " pre-push hook installed (quality gate + optional delta lint)" else echo " Skipped: not a git repository" fi diff --git a/scripts/pre-commit-safety.sh b/scripts/pre-commit-safety.sh index a4ec3286f3..7f1667dc91 100755 --- a/scripts/pre-commit-safety.sh +++ b/scripts/pre-commit-safety.sh @@ -136,6 +136,14 @@ fi PROD_DIFF="$DIFF_OUTPUT" # Strip hunks from test-only files (tests/ directory, *_test.rs, test_*.rs) PROD_DIFF=$(echo "$PROD_DIFF" | grep -v '^+++ b/tests/' || true) +# Strip hunks whose @@ context line indicates a test module. +# git diff includes the enclosing function/module name after @@. 
+# Only match `mod tests` (the conventional #[cfg(test)] module) — do NOT +# match `fn test_*` because production code can have functions named test_*. +PROD_DIFF=$(echo "$PROD_DIFF" | awk ' + /^@@ / { in_test = ($0 ~ /mod tests/) } + !in_test { print } +' || true) if echo "$PROD_DIFF" | grep -nE '^\+' \ | grep -E '\.(unwrap|expect)\(|[^_]assert(_eq|_ne)?!' \ | grep -vE 'debug_assert|// safety:|#\[cfg\(test\)\]|#\[test\]|mod tests' \ diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 8fda4143fa..5ca094e41a 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -838,19 +838,42 @@ impl Agent { }; if let Some(pending) = pending_auth { - match &submission { - Submission::UserInput { content } => { - return self - .process_auth_token(message, &pending, content, session, thread_id) - .await; - } - _ => { - // Any control submission (interrupt, undo, etc.) cancels auth mode + if pending.is_expired() { + // TTL exceeded — clear stale auth mode + tracing::warn!( + extension = %pending.extension_name, + "Auth mode expired after TTL, clearing" + ); + { let mut sess = session.lock().await; if let Some(thread) = sess.threads.get_mut(&thread_id) { thread.pending_auth = None; } - // Fall through to normal handling + } + // If this was a user message (possibly a pasted token), return an + // explicit error instead of forwarding it to the LLM/history. + if matches!(submission, Submission::UserInput { .. }) { + return Ok(Some(format!( + "Authentication for **{}** expired. Please try again.", + pending.extension_name + ))); + } + // Control submissions (interrupt, undo, etc.) fall through to normal handling + } else { + match &submission { + Submission::UserInput { content } => { + return self + .process_auth_token(message, &pending, content, session, thread_id) + .await; + } + _ => { + // Any control submission (interrupt, undo, etc.) 
cancels auth mode
+                        let mut sess = session.lock().await;
+                        if let Some(thread) = sess.threads.get_mut(&thread_id) {
+                            thread.pending_auth = None;
+                        }
+                        // Fall through to normal handling
+                    }
                 }
             }
         }
diff --git a/src/agent/session.rs b/src/agent/session.rs
index 0c1f1fd3bd..4abbea6168 100644
--- a/src/agent/session.rs
+++ b/src/agent/session.rs
@@ -12,7 +12,7 @@
 
 use std::collections::{HashMap, HashSet};
 
-use chrono::{DateTime, Utc};
+use chrono::{DateTime, TimeDelta, Utc};
 use serde::{Deserialize, Serialize};
 use uuid::Uuid;
 
@@ -135,6 +135,12 @@ pub enum ThreadState {
 
+/// Auth mode TTL — must stay in sync with
+/// `crate::cli::oauth_defaults::OAUTH_FLOW_EXPIRY` (5 minutes / 300 s).
+/// Defined separately to avoid a session→cli module dependency.
+const AUTH_MODE_TTL_SECS: i64 = 300;
+const AUTH_MODE_TTL: TimeDelta = TimeDelta::seconds(AUTH_MODE_TTL_SECS);
+
 /// Pending auth token request.
 ///
 /// When `tool_auth` returns `awaiting_token`, the thread enters auth mode.
 /// The next user message is intercepted before entering the normal pipeline
 /// (no logging, no turn creation, no history) and routed directly to the
@@ -143,6 +149,16 @@ pub enum ThreadState {
 pub struct PendingAuth {
     /// Extension name to authenticate.
     pub extension_name: String,
+    /// When this auth mode was entered. Used for TTL expiry.
+    #[serde(default = "Utc::now")]
+    pub created_at: DateTime<Utc>,
+}
+
+impl PendingAuth {
+    /// Returns `true` if this auth mode has exceeded the TTL.
+    pub fn is_expired(&self) -> bool {
+        Utc::now() - self.created_at > AUTH_MODE_TTL
+    }
 }
 
 /// Pending tool approval request stored on a thread.
@@ -298,7 +314,10 @@ impl Thread {
     /// Enter auth mode: next user message will be routed directly to
     /// the credential store, bypassing the normal pipeline entirely.
pub fn enter_auth_mode(&mut self, extension_name: String) { - self.pending_auth = Some(PendingAuth { extension_name }); + self.pending_auth = Some(PendingAuth { + extension_name, + created_at: Utc::now(), + }); self.updated_at = Utc::now(); } @@ -687,15 +706,16 @@ mod tests { #[test] fn test_enter_auth_mode() { + let before = Utc::now(); let mut thread = Thread::new(Uuid::new_v4()); assert!(thread.pending_auth.is_none()); thread.enter_auth_mode("telegram".to_string()); assert!(thread.pending_auth.is_some()); - assert_eq!( - thread.pending_auth.as_ref().unwrap().extension_name, - "telegram" - ); + let pending = thread.pending_auth.as_ref().unwrap(); + assert_eq!(pending.extension_name, "telegram"); + assert!(pending.created_at >= before); + assert!(!pending.is_expired()); } #[test] @@ -705,8 +725,9 @@ mod tests { let pending = thread.take_pending_auth(); assert!(pending.is_some()); - assert_eq!(pending.unwrap().extension_name, "notion"); - + let pending = pending.unwrap(); + assert_eq!(pending.extension_name, "notion"); + assert!(!pending.is_expired()); // Should be cleared after take assert!(thread.pending_auth.is_none()); assert!(thread.take_pending_auth().is_none()); @@ -720,10 +741,25 @@ mod tests { let json = serde_json::to_string(&thread).expect("should serialize"); assert!(json.contains("pending_auth")); assert!(json.contains("openai")); + assert!(json.contains("created_at")); let restored: Thread = serde_json::from_str(&json).expect("should deserialize"); assert!(restored.pending_auth.is_some()); - assert_eq!(restored.pending_auth.unwrap().extension_name, "openai"); + let pending = restored.pending_auth.unwrap(); + assert_eq!(pending.extension_name, "openai"); + assert!(!pending.is_expired()); + } + + #[test] + fn test_pending_auth_expiry() { + let mut pending = PendingAuth { + extension_name: "test".to_string(), + created_at: Utc::now(), + }; + assert!(!pending.is_expired()); + // Backdate beyond the TTL + pending.created_at = Utc::now() - AUTH_MODE_TTL - 
TimeDelta::seconds(1); + assert!(pending.is_expired()); } #[test] diff --git a/src/channels/web/server.rs b/src/channels/web/server.rs index acec384235..97d3293327 100644 --- a/src/channels/web/server.rs +++ b/src/channels/web/server.rs @@ -526,23 +526,33 @@ async fn oauth_callback_handler( .get("error_description") .cloned() .unwrap_or_else(|| error.clone()); + clear_auth_mode(&state).await; return oauth_error_page(&description); } let state_param = match params.get("state") { Some(s) if !s.is_empty() => s.clone(), - _ => return oauth_error_page("IronClaw"), + _ => { + clear_auth_mode(&state).await; + return oauth_error_page("IronClaw"); + } }; let code = match params.get("code") { Some(c) if !c.is_empty() => c.clone(), - _ => return oauth_error_page("IronClaw"), + _ => { + clear_auth_mode(&state).await; + return oauth_error_page("IronClaw"); + } }; // Look up the pending flow by CSRF state (atomic remove prevents replay) let ext_mgr = match state.extension_manager.as_ref() { Some(mgr) => mgr, - None => return oauth_error_page("IronClaw"), + None => { + clear_auth_mode(&state).await; + return oauth_error_page("IronClaw"); + } }; // Strip instance prefix from state for registry lookup. @@ -563,6 +573,7 @@ async fn oauth_callback_handler( lookup_key = %lookup_key, "OAuth callback received with unknown or expired state" ); + clear_auth_mode(&state).await; return oauth_error_page("IronClaw"); } }; @@ -581,6 +592,7 @@ async fn oauth_callback_handler( message: "OAuth flow expired. Please try again.".to_string(), }); } + clear_auth_mode(&state).await; return oauth_error_page(&flow.display_name); } @@ -690,6 +702,10 @@ async fn oauth_callback_handler( } } + // Clear auth mode regardless of outcome so the next user message goes + // through to the LLM instead of being intercepted as a token. + clear_auth_mode(&state).await; + // After successful OAuth, auto-activate the extension so it moves // from "Installed (Authenticate)" → "Active" without a second click. 
// OAuth success is independent of activation — tokens are already stored. @@ -2182,6 +2198,10 @@ async fn extensions_setup_submit_handler( "Extension manager not available (secrets store required)".to_string(), ))?; + // Clear auth mode regardless of outcome so the next user message goes + // through to the LLM instead of being intercepted as a token. + clear_auth_mode(&state).await; + match ext_mgr.configure(&name, &req.secrets).await { Ok(result) => { // Broadcast auth_completed so the chat UI can dismiss any in-progress diff --git a/src/channels/web/static/app.js b/src/channels/web/static/app.js index 081b0f3a5a..0624d07a3b 100644 --- a/src/channels/web/static/app.js +++ b/src/channels/web/static/app.js @@ -600,6 +600,22 @@ document.getElementById('chat-input').addEventListener('paste', (e) => { } }); +const chatMessagesEl = document.getElementById('chat-messages'); +chatMessagesEl.addEventListener('copy', (e) => { + const selection = window.getSelection(); + if (!selection || selection.isCollapsed) return; + const anchorNode = selection.anchorNode; + const focusNode = selection.focusNode; + if (!anchorNode || !focusNode) return; + if (!chatMessagesEl.contains(anchorNode) || !chatMessagesEl.contains(focusNode)) return; + const text = selection.toString(); + if (!text || !e.clipboardData) return; + // Force plain-text clipboard output so dark-theme styling never leaks on paste. + e.preventDefault(); + e.clipboardData.clearData(); + e.clipboardData.setData('text/plain', text); +}); + function addGeneratedImage(dataUrl, path) { const container = document.getElementById('chat-messages'); const card = document.createElement('div'); @@ -1759,7 +1775,10 @@ chatInput.addEventListener('keydown', (e) => { } } - if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { + // Safari fires compositionend before keydown, so e.isComposing is already false + // when Enter confirms IME input. keyCode 229 (VK_PROCESS) catches this case. 
+ // See https://bugs.webkit.org/show_bug.cgi?id=165004 + if (e.key === 'Enter' && !e.shiftKey && !e.isComposing && e.keyCode !== 229) { e.preventDefault(); hideSlashAutocomplete(); sendMessage(); diff --git a/src/cli/logs.rs b/src/cli/logs.rs new file mode 100644 index 0000000000..651bf891bc --- /dev/null +++ b/src/cli/logs.rs @@ -0,0 +1,587 @@ +//! CLI command for viewing and managing gateway logs. +//! +//! Provides access to gateway logs through three mechanisms: +//! - Reading the gateway log file (`~/.ironclaw/gateway.log`) +//! - Streaming live logs via the gateway's SSE endpoint (`/api/logs/events`) +//! - Getting/setting the runtime log level via `/api/logs/level` + +use std::io::{Seek, SeekFrom}; +use std::path::Path; + +use clap::Args; + +/// View and manage gateway logs. +#[derive(Args, Debug, Clone)] +#[command( + about = "View and manage gateway logs", + long_about = "Tail gateway logs, stream live output, or adjust log level.\nExamples:\n ironclaw logs # Show last 200 lines\n ironclaw logs --follow # Stream live logs via SSE\n ironclaw logs --limit 50 --json # Last 50 lines as JSON\n ironclaw logs --level # Show current log level\n ironclaw logs --level debug # Set log level to debug" +)] +pub struct LogsCommand { + /// Stream live logs from the running gateway via SSE. + /// Replays recent history then streams new entries in real time. 
+    #[arg(short, long)]
+    pub follow: bool,
+
+    /// Maximum number of lines to show (default: 200)
+    #[arg(short, long, default_value = "200")]
+    pub limit: usize,
+
+    /// Output log entries as JSON (one object per line)
+    #[arg(long)]
+    pub json: bool,
+
+    /// Display timestamps in local timezone
+    #[arg(long)]
+    pub local_time: bool,
+
+    /// Plain text output (no ANSI styling)
+    #[arg(long)]
+    pub plain: bool,
+
+    /// Gateway URL (default: http://{GATEWAY_HOST}:{GATEWAY_PORT})
+    #[arg(long)]
+    pub url: Option<String>,
+
+    /// Gateway auth token (reads GATEWAY_AUTH_TOKEN env if not set)
+    #[arg(long)]
+    pub token: Option<String>,
+
+    /// Connection timeout in milliseconds (default: 5000)
+    #[arg(long, default_value = "5000")]
+    pub timeout: u64,
+
+    /// Get or set runtime log level. Without a value, shows current level.
+    /// With a value (trace|debug|info|warn|error), sets the level.
+    #[arg(long, num_args = 0..=1, default_missing_value = "")]
+    pub level: Option<String>,
+}
+
+/// Resolved gateway connection parameters.
+struct GatewayParams {
+    base_url: String,
+    token: String,
+}
+
+/// Run the logs CLI command.
+pub async fn run_logs_command(cmd: LogsCommand, config_path: Option<&Path>) -> anyhow::Result<()> {
+    // --level takes priority: it's a control-plane operation, not log viewing.
+    if let Some(level_arg) = &cmd.level {
+        let params = resolve_gateway_params(&cmd, config_path).await?;
+        if level_arg.is_empty() {
+            return cmd_get_level(&cmd, &params).await;
+        } else {
+            return cmd_set_level(&cmd, level_arg, &params).await;
+        }
+    }
+
+    if cmd.follow {
+        let params = resolve_gateway_params(&cmd, config_path).await?;
+        cmd_follow(&cmd, &params).await
+    } else {
+        cmd_show(&cmd)
+    }
+}
+
+// ── Show log file ────────────────────────────────────────────────────────
+
+/// Read the last N lines from `~/.ironclaw/gateway.log`.
+///
+/// Uses a reverse-scan strategy: seeks to the end of the file and reads
+/// backwards in chunks to find the last `limit` newlines, so memory usage
+/// is proportional to the output size, not the file size.
+fn cmd_show(cmd: &LogsCommand) -> anyhow::Result<()> {
+    let log_path = crate::bootstrap::ironclaw_base_dir().join("gateway.log");
+    if !log_path.exists() {
+        anyhow::bail!(
+            "No gateway log file found at {}.\n\
+             The log file is created when the gateway runs in background mode \
+             (e.g. `ironclaw gateway start`).",
+            log_path.display()
+        );
+    }
+
+    let lines = tail_file(&log_path, cmd.limit)?;
+
+    if lines.is_empty() {
+        println!("(log file is empty)");
+        return Ok(());
+    }
+
+    if cmd.json {
+        for line in &lines {
+            let obj = serde_json::json!({ "line": line });
+            println!("{}", obj);
+        }
+    } else {
+        for line in &lines {
+            println!("{}", line);
+        }
+    }
+
+    Ok(())
+}
+
+/// Read the last `n` lines from a file by scanning backwards from EOF.
+///
+/// Reads in 8 KiB chunks from the end, counting newlines until enough
+/// are found or the beginning of the file is reached.
+fn tail_file(path: &Path, n: usize) -> anyhow::Result<Vec<String>> {
+    let mut file = std::fs::File::open(path)
+        .map_err(|e| anyhow::anyhow!("Failed to open {}: {}", path.display(), e))?;
+
+    let file_len = file
+        .seek(SeekFrom::End(0))
+        .map_err(|e| anyhow::anyhow!("Failed to seek {}: {}", path.display(), e))?;
+
+    if file_len == 0 {
+        return Ok(Vec::new());
+    }
+
+    // Read backwards in chunks to find enough newlines.
+ const CHUNK_SIZE: u64 = 8192; + let mut tail_bytes = Vec::new(); + let mut newline_count = 0; + let mut remaining = file_len; + + while remaining > 0 && newline_count <= n { + let read_size = std::cmp::min(CHUNK_SIZE, remaining); + remaining -= read_size; + + file.seek(SeekFrom::Start(remaining)) + .map_err(|e| anyhow::anyhow!("Seek failed: {e}"))?; + + let mut chunk = vec![0u8; read_size as usize]; + std::io::Read::read_exact(&mut file, &mut chunk) + .map_err(|e| anyhow::anyhow!("Read failed: {e}"))?; + + // Count newlines in this chunk (backwards). + for &byte in chunk.iter().rev() { + if byte == b'\n' { + newline_count += 1; + } + } + + // Prepend chunk to collected bytes. + chunk.append(&mut tail_bytes); + tail_bytes = chunk; + } + + // Convert to string and take last N lines. + let text = String::from_utf8_lossy(&tail_bytes); + let all_lines: Vec<&str> = text.lines().collect(); + let start = all_lines.len().saturating_sub(n); + + Ok(all_lines[start..].iter().map(|s| s.to_string()).collect()) +} + +// ── Follow (live SSE stream) ───────────────────────────────────────────── + +/// Connect to the gateway's `/api/logs/events` SSE endpoint and stream logs. +async fn cmd_follow(cmd: &LogsCommand, params: &GatewayParams) -> anyhow::Result<()> { + let timeout_dur = std::time::Duration::from_millis(cmd.timeout); + + let client = reqwest::Client::builder() + .connect_timeout(timeout_dur) + .build() + .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {e}"))?; + + let url = format!("{}/api/logs/events", params.base_url); + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {}", params.token)) + .header("Accept", "text/event-stream") + // No per-request timeout: SSE streams are long-lived. + .timeout(std::time::Duration::from_secs(u64::MAX / 2)) + .send() + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to connect to gateway at {url}: {e}\n\ + Is the gateway running? Try `ironclaw gateway status`." 
+            )
+        })?;
+
+    if !resp.status().is_success() {
+        anyhow::bail!(
+            "Gateway returned HTTP {}: {}",
+            resp.status(),
+            resp.text().await.unwrap_or_default()
+        );
+    }
+
+    eprintln!("Connected to {} — streaming logs (Ctrl-C to stop)", url);
+
+    // Parse SSE stream line by line; raw bytes are buffered so a multi-byte UTF-8 char split across chunks is decoded intact.
+    let mut bytes_stream = resp.bytes_stream();
+    let mut buffer: Vec<u8> = Vec::new();
+    let mut lines_shown: usize = 0;
+
+    use futures::StreamExt;
+    while let Some(chunk) = bytes_stream.next().await {
+        let chunk = chunk.map_err(|e| anyhow::anyhow!("Stream error: {e}"))?;
+        buffer.extend_from_slice(&chunk);
+
+        // Process complete lines from the buffer.
+        while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
+            let line = String::from_utf8_lossy(&buffer[..newline_pos]).into_owned();
+            buffer.drain(..=newline_pos);
+
+            // SSE format: "data: {...}" lines carry the payload.
+            if let Some(data) = line.strip_prefix("data: ")
+                && let Ok(entry) = serde_json::from_str::<serde_json::Value>(data)
+            {
+                print_log_entry(&entry, cmd);
+                lines_shown += 1;
+            }
+            // Skip "event:", "id:", "retry:", and empty keepalive lines.
+        }
+    }
+
+    if lines_shown == 0 {
+        eprintln!("(no log entries received)");
+    }
+
+    Ok(())
+}
+
+// ── Log level get/set ────────────────────────────────────────────────────
+
+/// GET /api/logs/level — show the current log level.
+async fn cmd_get_level(cmd: &LogsCommand, params: &GatewayParams) -> anyhow::Result<()> {
+    let timeout_dur = std::time::Duration::from_millis(cmd.timeout);
+
+    let client = reqwest::Client::builder()
+        .timeout(timeout_dur)
+        .build()
+        .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {e}"))?;
+
+    let url = format!("{}/api/logs/level", params.base_url);
+    let resp = client
+        .get(&url)
+        .header("Authorization", format!("Bearer {}", params.token))
+        .send()
+        .await
+        .map_err(|e| {
+            anyhow::anyhow!(
+                "Failed to connect to gateway at {url}: {e}\n\
+                 Is the gateway running? Try `ironclaw gateway status`."
+ ) + })?; + + if !resp.status().is_success() { + anyhow::bail!( + "Gateway returned HTTP {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ); + } + + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| anyhow::anyhow!("Invalid response: {e}"))?; + + if cmd.json { + println!( + "{}", + serde_json::to_string_pretty(&body).unwrap_or_default() + ); + } else { + let level = body + .get("level") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + println!("Current log level: {}", level); + } + + Ok(()) +} + +/// PUT /api/logs/level — change the runtime log level. +async fn cmd_set_level( + cmd: &LogsCommand, + level: &str, + params: &GatewayParams, +) -> anyhow::Result<()> { + const VALID: &[&str] = &["trace", "debug", "info", "warn", "error"]; + let level_lower = level.to_lowercase(); + if !VALID.contains(&level_lower.as_str()) { + anyhow::bail!( + "Invalid log level '{}'. Must be one of: {}", + level, + VALID.join(", ") + ); + } + + let timeout_dur = std::time::Duration::from_millis(cmd.timeout); + + let client = reqwest::Client::builder() + .timeout(timeout_dur) + .build() + .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {e}"))?; + + let url = format!("{}/api/logs/level", params.base_url); + let resp = client + .put(&url) + .header("Authorization", format!("Bearer {}", params.token)) + .json(&serde_json::json!({ "level": level_lower })) + .send() + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to connect to gateway at {url}: {e}\n\ + Is the gateway running? Try `ironclaw gateway status`." 
+            )
+        })?;
+
+    if !resp.status().is_success() {
+        anyhow::bail!(
+            "Gateway returned HTTP {}: {}",
+            resp.status(),
+            resp.text().await.unwrap_or_default()
+        );
+    }
+
+    let body: serde_json::Value = resp
+        .json()
+        .await
+        .map_err(|e| anyhow::anyhow!("Invalid response: {e}"))?;
+
+    if cmd.json {
+        println!(
+            "{}",
+            serde_json::to_string_pretty(&body).unwrap_or_default()
+        );
+    } else {
+        let new_level = body
+            .get("level")
+            .and_then(|v| v.as_str())
+            .unwrap_or(&level_lower);
+        println!("Log level set to: {}", new_level);
+    }
+
+    Ok(())
+}
+
+// ── Helpers ──────────────────────────────────────────────────────────────
+
+/// Resolve gateway connection params from CLI flags, config file, or env.
+///
+/// Priority: --url/--token flags > config TOML > env vars > defaults.
+async fn resolve_gateway_params(
+    cmd: &LogsCommand,
+    config_path: Option<&Path>,
+) -> anyhow::Result<GatewayParams> {
+    // Load gateway config. Errors propagate when --config is explicit.
+    let gw_config = load_gateway_config(config_path).await?;
+
+    // URL: --url flag > config TOML > env vars > defaults.
+    let base_url = if let Some(url) = &cmd.url {
+        url.trim_end_matches('/').to_string()
+    } else if let Some(cfg) = &gw_config {
+        format!("http://{}:{}", cfg.host, cfg.port)
+    } else {
+        let host = std::env::var("GATEWAY_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
+        let port: u16 = std::env::var("GATEWAY_PORT")
+            .ok()
+            .and_then(|p| p.parse().ok())
+            .unwrap_or(3000);
+        format!("http://{}:{}", host, port)
+    };
+
+    // Token: --token flag > config TOML > env var.
+    let token = if let Some(token) = &cmd.token {
+        token.clone()
+    } else if let Some(t) = gw_config.as_ref().and_then(|c| c.auth_token.clone()) {
+        t
+    } else {
+        std::env::var("GATEWAY_AUTH_TOKEN").map_err(|_| {
+            anyhow::anyhow!(
+                "No auth token provided. Use --token or set GATEWAY_AUTH_TOKEN.\n\
+                 The token is printed when the gateway starts."
+            )
+        })?
+ }; + + Ok(GatewayParams { base_url, token }) +} + +/// Try to load gateway config from the TOML config file. +/// +/// If `config_path` was explicitly provided (via `--config`), errors are +/// propagated — the user asked for a specific file and deserves a clear +/// failure when it is missing, unreadable, or malformed. When no path +/// was given we fall back to env-only resolution and silently return +/// `None` on failure so that `ironclaw logs` works without any config. +async fn load_gateway_config( + config_path: Option<&Path>, +) -> anyhow::Result> { + if config_path.is_some() { + // Explicit --config: propagate errors. + let config = crate::config::Config::from_env_with_toml(config_path) + .await + .map_err(|e| anyhow::anyhow!("{e:#}"))?; + Ok(config.channels.gateway) + } else { + // No explicit config: best-effort, swallow errors. + let config = crate::config::Config::from_env_with_toml(None).await.ok(); + Ok(config.and_then(|c| c.channels.gateway)) + } +} + +/// Print a single log entry to stdout. +fn print_log_entry(entry: &serde_json::Value, cmd: &LogsCommand) { + if cmd.json { + println!("{}", serde_json::to_string(entry).unwrap_or_default()); + return; + } + + let level = entry.get("level").and_then(|v| v.as_str()).unwrap_or("?"); + let target = entry.get("target").and_then(|v| v.as_str()).unwrap_or(""); + let message = entry.get("message").and_then(|v| v.as_str()).unwrap_or(""); + let timestamp = entry + .get("timestamp") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let display_ts = if cmd.local_time { + convert_to_local_time(timestamp) + } else { + timestamp.to_string() + }; + + if cmd.plain { + println!("{} {} [{}] {}", display_ts, level, target, message); + } else { + let level_colored = colorize_level(level); + println!("{} {} [{}] {}", display_ts, level_colored, target, message); + } +} + +/// Convert an RFC 3339 timestamp to local time display. 
+fn convert_to_local_time(ts: &str) -> String { + chrono::DateTime::parse_from_rfc3339(ts) + .map(|dt| { + dt.with_timezone(&chrono::Local) + .format("%Y-%m-%dT%H:%M:%S%.3f") + .to_string() + }) + .unwrap_or_else(|_| ts.to_string()) +} + +/// Apply ANSI color to log level for terminal display. +fn colorize_level(level: &str) -> String { + match level { + "ERROR" => format!("\x1b[31m{}\x1b[0m", level), // red + "WARN" => format!("\x1b[33m{}\x1b[0m", level), // yellow + "INFO" => format!("\x1b[32m{}\x1b[0m", level), // green + "DEBUG" => format!("\x1b[36m{}\x1b[0m", level), // cyan + "TRACE" => format!("\x1b[90m{}\x1b[0m", level), // gray + _ => level.to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_colorize_level() { + assert!(colorize_level("ERROR").contains("\x1b[31m")); + assert!(colorize_level("WARN").contains("\x1b[33m")); + assert!(colorize_level("INFO").contains("\x1b[32m")); + assert!(colorize_level("DEBUG").contains("\x1b[36m")); + assert!(colorize_level("TRACE").contains("\x1b[90m")); + assert_eq!(colorize_level("UNKNOWN"), "UNKNOWN"); + } + + #[test] + fn test_convert_to_local_time_valid() { + let ts = "2024-01-15T10:30:00.000Z"; + let result = convert_to_local_time(ts); + assert!(result.contains("2024-01-15")); + } + + #[test] + fn test_convert_to_local_time_invalid() { + let ts = "not-a-timestamp"; + assert_eq!(convert_to_local_time(ts), "not-a-timestamp"); + } + + #[test] + fn test_print_log_entry_json() { + let entry = serde_json::json!({ + "level": "INFO", + "target": "ironclaw::agent", + "message": "test message", + "timestamp": "2024-01-15T10:30:00.000Z" + }); + let cmd = LogsCommand { + follow: false, + limit: 200, + json: true, + local_time: false, + plain: false, + url: None, + token: None, + timeout: 5000, + level: None, + }; + // Should not panic + print_log_entry(&entry, &cmd); + } + + #[test] + fn test_tail_file_small() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.log"); 
+ std::fs::write(&path, "line1\nline2\nline3\nline4\nline5\n").unwrap(); + + let result = tail_file(&path, 3).unwrap(); + assert_eq!(result, vec!["line3", "line4", "line5"]); + } + + #[test] + fn test_tail_file_fewer_lines_than_limit() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.log"); + std::fs::write(&path, "a\nb\n").unwrap(); + + let result = tail_file(&path, 200).unwrap(); + assert_eq!(result, vec!["a", "b"]); + } + + #[test] + fn test_tail_file_empty() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.log"); + std::fs::write(&path, "").unwrap(); + + let result = tail_file(&path, 10).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_tail_file_large() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("big.log"); + // Write 10000 lines to test chunked reading. + let content: String = (0..10000).map(|i| format!("line {}\n", i)).collect(); + std::fs::write(&path, &content).unwrap(); + + let result = tail_file(&path, 5).unwrap(); + assert_eq!(result.len(), 5); + assert_eq!(result[0], "line 9995"); + assert_eq!(result[4], "line 9999"); + } + + #[test] + fn test_tail_file_no_trailing_newline() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.log"); + std::fs::write(&path, "line1\nline2\nline3").unwrap(); + + let result = tail_file(&path, 2).unwrap(); + assert_eq!(result, vec!["line2", "line3"]); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 652cac01ea..cf3c793e81 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -11,6 +11,7 @@ //! - Managing OS service (`service install`, `service start`, `service stop`) //! - Listing configured channels (`channels list`) //! - Active health diagnostics (`doctor`) +//! - Viewing gateway logs (`logs`) //! 
- Checking system health (`status`) mod channels; @@ -19,6 +20,7 @@ mod config; mod doctor; #[cfg(feature = "import")] pub mod import; +mod logs; mod mcp; pub mod memory; pub mod oauth_defaults; @@ -36,6 +38,7 @@ pub use config::{ConfigCommand, run_config_command}; pub use doctor::run_doctor_command; #[cfg(feature = "import")] pub use import::{ImportCommand, run_import_command}; +pub use logs::{LogsCommand, run_logs_command}; pub use mcp::{McpCommand, run_mcp_command}; pub use memory::MemoryCommand; pub use memory::run_memory_command_with_db; @@ -206,6 +209,13 @@ pub enum Command { )] Doctor, + /// View and manage gateway logs + #[command( + about = "View and manage gateway logs", + long_about = "Tail gateway logs, stream live output, or adjust log level.\nExamples:\n ironclaw logs # Show last 200 lines from gateway.log\n ironclaw logs --follow # Stream live logs via SSE\n ironclaw logs --level # Show current log level\n ironclaw logs --level debug # Set log level to debug" + )] + Logs(LogsCommand), + /// Show system health and diagnostics #[command( about = "Show system status", diff --git a/src/cli/snapshots/ironclaw__cli__tests__help_output.snap b/src/cli/snapshots/ironclaw__cli__tests__help_output.snap new file mode 100644 index 0000000000..a554acaeba --- /dev/null +++ b/src/cli/snapshots/ironclaw__cli__tests__help_output.snap @@ -0,0 +1,36 @@ +--- +source: src/cli/mod.rs +expression: help +--- +Secure personal AI assistant that protects your data and expands its capabilities + +Usage: ironclaw [OPTIONS] [COMMAND] + +Commands: + run Run the AI agent + onboard Run interactive setup wizard + config Manage app configs + tool Manage WASM tools + registry Browse/install extensions + channels Manage channels + routines Manage routines + mcp Manage MCP servers + memory Manage workspace memory + pairing Manage DM pairing + service Manage OS service + skills Manage skills + doctor Run diagnostics + logs View and manage gateway logs + status Show system status + 
completion Generate completions + import Import from other AI systems + help Print this message or the help of the given subcommand(s) + +Options: + --cli-only Run in interactive CLI mode only (disable other channels) + --no-db Skip database connection (for testing) + -m, --message Single message mode - send one message and exit + -c, --config Configuration file path (optional, uses env vars by default) + --no-onboard Skip first-run onboarding check + -h, --help Print help (see more with '--help') + -V, --version Print version diff --git a/src/cli/snapshots/ironclaw__cli__tests__help_output_without_import.snap b/src/cli/snapshots/ironclaw__cli__tests__help_output_without_import.snap index c7d8db1331..3f3cf4fc0b 100644 --- a/src/cli/snapshots/ironclaw__cli__tests__help_output_without_import.snap +++ b/src/cli/snapshots/ironclaw__cli__tests__help_output_without_import.snap @@ -20,6 +20,7 @@ Commands: service Manage OS service skills Manage skills doctor Run diagnostics + logs View and manage gateway logs status Show system status completion Generate completions help Print this message or the help of the given subcommand(s) diff --git a/src/cli/snapshots/ironclaw__cli__tests__long_help_output.snap b/src/cli/snapshots/ironclaw__cli__tests__long_help_output.snap new file mode 100644 index 0000000000..99b3ef53bb --- /dev/null +++ b/src/cli/snapshots/ironclaw__cli__tests__long_help_output.snap @@ -0,0 +1,52 @@ +--- +source: src/cli/mod.rs +expression: help +--- +IronClaw is a secure AI assistant. Use 'ironclaw --help' for details. 
+Examples: + ironclaw run # Start the agent + ironclaw config list # List configs + +Usage: ironclaw [OPTIONS] [COMMAND] + +Commands: + run Run the AI agent + onboard Run interactive setup wizard + config Manage app configs + tool Manage WASM tools + registry Browse/install extensions + channels Manage channels + routines Manage routines + mcp Manage MCP servers + memory Manage workspace memory + pairing Manage DM pairing + service Manage OS service + skills Manage skills + doctor Run diagnostics + logs View and manage gateway logs + status Show system status + completion Generate completions + import Import from other AI systems + help Print this message or the help of the given subcommand(s) + +Options: + --cli-only + Run in interactive CLI mode only (disable other channels) + + --no-db + Skip database connection (for testing) + + -m, --message + Single message mode - send one message and exit + + -c, --config + Configuration file path (optional, uses env vars by default) + + --no-onboard + Skip first-run onboarding check + + -h, --help + Print help (see a summary with '-h') + + -V, --version + Print version diff --git a/src/cli/snapshots/ironclaw__cli__tests__long_help_output_without_import.snap b/src/cli/snapshots/ironclaw__cli__tests__long_help_output_without_import.snap index fb4ad2313e..aa7ae8b0fe 100644 --- a/src/cli/snapshots/ironclaw__cli__tests__long_help_output_without_import.snap +++ b/src/cli/snapshots/ironclaw__cli__tests__long_help_output_without_import.snap @@ -23,6 +23,7 @@ Commands: service Manage OS service skills Manage skills doctor Run diagnostics + logs View and manage gateway logs status Show system status completion Generate completions help Print this message or the help of the given subcommand(s) diff --git a/src/config/channels.rs b/src/config/channels.rs index 90635c22f5..981b017008 100644 --- a/src/config/channels.rs +++ b/src/config/channels.rs @@ -91,11 +91,20 @@ pub struct SignalConfig { } impl ChannelsConfig { + /// Resolve 
channels config following `env > settings > default` for every field. pub(crate) fn resolve(settings: &Settings) -> Result { - let http = if optional_env("HTTP_PORT")?.is_some() || optional_env("HTTP_HOST")?.is_some() { + let cs = &settings.channels; + + // --- HTTP webhook --- + // HTTP is enabled when env vars are set OR settings has it enabled. + let http_enabled_by_env = + optional_env("HTTP_PORT")?.is_some() || optional_env("HTTP_HOST")?.is_some(); + let http = if http_enabled_by_env || cs.http_enabled { Some(HttpConfig { - host: optional_env("HTTP_HOST")?.unwrap_or_else(|| "0.0.0.0".to_string()), - port: parse_optional_env("HTTP_PORT", 8080)?, + host: optional_env("HTTP_HOST")? + .or_else(|| cs.http_host.clone()) + .unwrap_or_else(|| "0.0.0.0".to_string()), + port: parse_optional_env("HTTP_PORT", cs.http_port.unwrap_or(8080))?, webhook_secret: optional_env("HTTP_WEBHOOK_SECRET")?.map(SecretString::from), user_id: optional_env("HTTP_USER_ID")?.unwrap_or_else(|| "http".to_string()), }) @@ -103,42 +112,58 @@ impl ChannelsConfig { None }; - let gateway_enabled = parse_bool_env("GATEWAY_ENABLED", true)?; + // --- Web gateway --- + let gateway_enabled = parse_bool_env("GATEWAY_ENABLED", cs.gateway_enabled)?; let gateway = if gateway_enabled { Some(GatewayConfig { - host: optional_env("GATEWAY_HOST")?.unwrap_or_else(|| "127.0.0.1".to_string()), - port: parse_optional_env("GATEWAY_PORT", 3000)?, - auth_token: optional_env("GATEWAY_AUTH_TOKEN")?, - user_id: optional_env("GATEWAY_USER_ID")?.unwrap_or_else(|| "default".to_string()), + host: optional_env("GATEWAY_HOST")? + .or_else(|| cs.gateway_host.clone()) + .unwrap_or_else(|| "127.0.0.1".to_string()), + port: parse_optional_env( + "GATEWAY_PORT", + cs.gateway_port.unwrap_or(DEFAULT_GATEWAY_PORT), + )?, + auth_token: optional_env("GATEWAY_AUTH_TOKEN")? + .or_else(|| cs.gateway_auth_token.clone()), + user_id: optional_env("GATEWAY_USER_ID")? 
+ .or_else(|| cs.gateway_user_id.clone()) + .unwrap_or_else(|| "default".to_string()), }) } else { None }; - let signal = if let Some(http_url) = optional_env("SIGNAL_HTTP_URL")? { - let account = optional_env("SIGNAL_ACCOUNT")?.ok_or(ConfigError::InvalidValue { - key: "SIGNAL_ACCOUNT".to_string(), - message: "SIGNAL_ACCOUNT is required when SIGNAL_HTTP_URL is set".to_string(), - })?; - let allow_from = match std::env::var_os("SIGNAL_ALLOW_FROM") { + // --- Signal --- + let signal_url = optional_env("SIGNAL_HTTP_URL")?.or_else(|| cs.signal_http_url.clone()); + let signal = if let Some(http_url) = signal_url { + let account = optional_env("SIGNAL_ACCOUNT")? + .or_else(|| cs.signal_account.clone()) + .ok_or(ConfigError::InvalidValue { + key: "SIGNAL_ACCOUNT".to_string(), + message: "SIGNAL_ACCOUNT is required when Signal is enabled".to_string(), + })?; + let allow_from_str = + optional_env("SIGNAL_ALLOW_FROM")?.or_else(|| cs.signal_allow_from.clone()); + let allow_from = match allow_from_str { None => vec![account.clone()], - Some(val) => { - let s = val.to_string_lossy(); - s.split(',') - .map(|e| e.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect() - } + Some(s) => s + .split(',') + .map(|e| e.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(), }; - let dm_policy = - optional_env("SIGNAL_DM_POLICY")?.unwrap_or_else(|| "pairing".to_string()); - let group_policy = - optional_env("SIGNAL_GROUP_POLICY")?.unwrap_or_else(|| "allowlist".to_string()); + let dm_policy = optional_env("SIGNAL_DM_POLICY")? + .or_else(|| cs.signal_dm_policy.clone()) + .unwrap_or_else(|| "pairing".to_string()); + let group_policy = optional_env("SIGNAL_GROUP_POLICY")? + .or_else(|| cs.signal_group_policy.clone()) + .unwrap_or_else(|| "allowlist".to_string()); Some(SignalConfig { http_url, account, allow_from, allow_from_groups: optional_env("SIGNAL_ALLOW_FROM_GROUPS")? 
+ .or_else(|| cs.signal_allow_from_groups.clone()) .map(|s| { s.split(',') .map(|e| e.trim().to_string()) @@ -149,6 +174,7 @@ impl ChannelsConfig { dm_policy, group_policy, group_allow_from: optional_env("SIGNAL_GROUP_ALLOW_FROM")? + .or_else(|| cs.signal_group_allow_from.clone()) .map(|s| { s.split(',') .map(|e| e.trim().to_string()) @@ -167,9 +193,17 @@ impl ChannelsConfig { None }; - let cli_enabled = optional_env("CLI_ENABLED")? - .map(|s| s.to_lowercase() != "false" && s != "0") - .unwrap_or(true); + // --- CLI --- + let cli_enabled = parse_bool_env("CLI_ENABLED", cs.cli_enabled)?; + + // --- WASM channels --- + let wasm_channels_dir = optional_env("WASM_CHANNELS_DIR")? + .map(PathBuf::from) + .or_else(|| cs.wasm_channels_dir.clone()) + .unwrap_or_else(default_channels_dir); + + let wasm_channels_enabled = + parse_bool_env("WASM_CHANNELS_ENABLED", cs.wasm_channels_enabled)?; Ok(Self { cli: CliConfig { @@ -178,12 +212,10 @@ impl ChannelsConfig { http, gateway, signal, - wasm_channels_dir: optional_env("WASM_CHANNELS_DIR")? - .map(PathBuf::from) - .unwrap_or_else(default_channels_dir), - wasm_channels_enabled: parse_bool_env("WASM_CHANNELS_ENABLED", true)?, + wasm_channels_dir, + wasm_channels_enabled, wasm_channel_owner_ids: { - let mut ids = settings.channels.wasm_channel_owner_ids.clone(); + let mut ids = cs.wasm_channel_owner_ids.clone(); // Backwards compat: TELEGRAM_OWNER_ID env var if let Some(id_str) = optional_env("TELEGRAM_OWNER_ID")? { let id: i64 = id_str.parse().map_err(|e: std::num::ParseIntError| { @@ -200,6 +232,10 @@ impl ChannelsConfig { } } +/// Default gateway port — used both in `resolve()` and as the fallback in +/// other modules that need to construct a gateway URL. +pub const DEFAULT_GATEWAY_PORT: u16 = 3000; + /// Get the default channels directory (~/.ironclaw/channels/). 
fn default_channels_dir() -> PathBuf { ironclaw_base_dir().join("channels") @@ -362,4 +398,244 @@ mod tests { "expected path ending in 'channels', got: {dir:?}" ); } + + #[test] + fn default_gateway_port_constant() { + assert_eq!(DEFAULT_GATEWAY_PORT, 3000); + } + + /// With default settings and no env vars, gateway should use defaults. + #[test] + fn resolve_gateway_defaults_from_settings() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + // Clear env vars that would interfere + unsafe { + std::env::remove_var("GATEWAY_ENABLED"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + + let settings = crate::settings::Settings::default(); + let cfg = ChannelsConfig::resolve(&settings).unwrap(); + + let gw = cfg.gateway.expect("gateway should be enabled by default"); + assert_eq!(gw.host, "127.0.0.1"); + assert_eq!(gw.port, DEFAULT_GATEWAY_PORT); + assert!(gw.auth_token.is_none()); + assert_eq!(gw.user_id, "default"); + } + + /// Settings values should be used when no env vars are set. 
+ #[test] + fn resolve_gateway_from_settings() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::remove_var("GATEWAY_ENABLED"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + + let mut settings = crate::settings::Settings::default(); + settings.channels.gateway_port = Some(4000); + settings.channels.gateway_host = Some("0.0.0.0".to_string()); + settings.channels.gateway_auth_token = Some("db-token-123".to_string()); + settings.channels.gateway_user_id = Some("myuser".to_string()); + + let cfg = ChannelsConfig::resolve(&settings).unwrap(); + let gw = cfg.gateway.expect("gateway should be enabled"); + assert_eq!(gw.port, 4000); + assert_eq!(gw.host, "0.0.0.0"); + assert_eq!(gw.auth_token.as_deref(), Some("db-token-123")); + assert_eq!(gw.user_id, "myuser"); + } + + /// Env vars should override settings values. 
+ #[test] + fn resolve_env_overrides_settings() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::set_var("GATEWAY_PORT", "5000"); + std::env::set_var("GATEWAY_HOST", "10.0.0.1"); + std::env::set_var("GATEWAY_AUTH_TOKEN", "env-token"); + std::env::remove_var("GATEWAY_ENABLED"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + + let mut settings = crate::settings::Settings::default(); + settings.channels.gateway_port = Some(4000); + settings.channels.gateway_host = Some("0.0.0.0".to_string()); + settings.channels.gateway_auth_token = Some("db-token".to_string()); + + let cfg = ChannelsConfig::resolve(&settings).unwrap(); + let gw = cfg.gateway.expect("gateway should be enabled"); + assert_eq!(gw.port, 5000, "env should override settings"); + assert_eq!(gw.host, "10.0.0.1", "env should override settings"); + assert_eq!( + gw.auth_token.as_deref(), + Some("env-token"), + "env should override settings" + ); + + // Cleanup + unsafe { + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + } + } + + /// CLI enabled should fall back to settings. 
+ #[test] + fn resolve_cli_enabled_from_settings() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("GATEWAY_ENABLED"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + + let mut settings = crate::settings::Settings::default(); + settings.channels.cli_enabled = false; + + let cfg = ChannelsConfig::resolve(&settings).unwrap(); + assert!(!cfg.cli.enabled, "settings should disable CLI"); + } + + /// HTTP channel should activate when settings has it enabled. + #[test] + fn resolve_http_from_settings() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("HTTP_WEBHOOK_SECRET"); + std::env::remove_var("HTTP_USER_ID"); + std::env::remove_var("GATEWAY_ENABLED"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + + let mut settings = crate::settings::Settings::default(); + settings.channels.http_enabled = true; + settings.channels.http_port = Some(9090); + settings.channels.http_host = Some("10.0.0.1".to_string()); + + let cfg = ChannelsConfig::resolve(&settings).unwrap(); + let http = cfg.http.expect("HTTP should be enabled from settings"); + assert_eq!(http.port, 
9090); + assert_eq!(http.host, "10.0.0.1"); + } + + /// Settings round-trip through DB map for new gateway fields. + #[test] + fn settings_gateway_fields_db_roundtrip() { + let mut settings = crate::settings::Settings::default(); + settings.channels.gateway_port = Some(4000); + settings.channels.gateway_host = Some("0.0.0.0".to_string()); + settings.channels.gateway_auth_token = Some("tok-abc".to_string()); + settings.channels.gateway_user_id = Some("myuser".to_string()); + settings.channels.cli_enabled = false; + + let map = settings.to_db_map(); + let restored = crate::settings::Settings::from_db_map(&map); + + assert_eq!(restored.channels.gateway_port, Some(4000)); + assert_eq!(restored.channels.gateway_host.as_deref(), Some("0.0.0.0")); + assert_eq!( + restored.channels.gateway_auth_token.as_deref(), + Some("tok-abc") + ); + assert_eq!(restored.channels.gateway_user_id.as_deref(), Some("myuser")); + assert!(!restored.channels.cli_enabled); + } + + /// Invalid boolean env values must produce errors, not silently degrade. 
+ #[test] + fn resolve_rejects_invalid_bool_env() { + let _lock = crate::config::helpers::ENV_MUTEX.lock(); + let settings = crate::settings::Settings::default(); + + // GATEWAY_ENABLED=maybe should error + unsafe { + std::env::set_var("GATEWAY_ENABLED", "maybe"); + std::env::remove_var("HTTP_PORT"); + std::env::remove_var("HTTP_HOST"); + std::env::remove_var("SIGNAL_HTTP_URL"); + std::env::remove_var("CLI_ENABLED"); + std::env::remove_var("WASM_CHANNELS_ENABLED"); + std::env::remove_var("GATEWAY_PORT"); + std::env::remove_var("GATEWAY_HOST"); + std::env::remove_var("GATEWAY_AUTH_TOKEN"); + std::env::remove_var("GATEWAY_USER_ID"); + std::env::remove_var("WASM_CHANNELS_DIR"); + std::env::remove_var("TELEGRAM_OWNER_ID"); + } + let result = ChannelsConfig::resolve(&settings); + assert!(result.is_err(), "GATEWAY_ENABLED=maybe should be rejected"); + + // CLI_ENABLED=on should error + unsafe { + std::env::remove_var("GATEWAY_ENABLED"); + std::env::set_var("CLI_ENABLED", "on"); + } + let result = ChannelsConfig::resolve(&settings); + assert!(result.is_err(), "CLI_ENABLED=on should be rejected"); + + // WASM_CHANNELS_ENABLED=yes should error + unsafe { + std::env::remove_var("CLI_ENABLED"); + std::env::set_var("WASM_CHANNELS_ENABLED", "yes"); + } + let result = ChannelsConfig::resolve(&settings); + assert!( + result.is_err(), + "WASM_CHANNELS_ENABLED=yes should be rejected" + ); + + // Cleanup + unsafe { + std::env::remove_var("WASM_CHANNELS_ENABLED"); + } + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 34c34423ac..0ce8dfecc5 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -34,7 +34,9 @@ use crate::settings::Settings; // Re-export all public types so `crate::config::FooConfig` continues to work. 
pub use self::agent::AgentConfig; pub use self::builder::BuilderModeConfig; -pub use self::channels::{ChannelsConfig, CliConfig, GatewayConfig, HttpConfig, SignalConfig}; +pub use self::channels::{ + ChannelsConfig, CliConfig, DEFAULT_GATEWAY_PORT, GatewayConfig, HttpConfig, SignalConfig, +}; pub use self::database::{DatabaseBackend, DatabaseConfig, SslMode, default_libsql_path}; pub use self::embeddings::EmbeddingsConfig; pub use self::heartbeat::HeartbeatConfig; diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index 05b07555ae..e057e2acc1 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -2864,9 +2864,16 @@ impl ExtensionManager { // Try to list and create tools. // A 401/auth error means the server requires OAuth — surface as // AuthRequired so the activate handler triggers the OAuth flow. + // Some servers (e.g. GitHub MCP) return 400 with "Authorization header + // is badly formatted" instead of 401 when auth is missing or invalid. let mcp_tools = client.list_tools().await.map_err(|e| { let msg = e.to_string(); - if msg.contains("requires authentication") || msg.contains("401") { + let msg_lower = msg.to_ascii_lowercase(); + if msg_lower.contains("requires authentication") + || msg.contains("401") + || (msg.contains("400") + && (msg_lower.contains("authorization") || msg_lower.contains("authenticate"))) + { ExtensionError::AuthRequired } else { ExtensionError::ActivationFailed(msg) @@ -3444,7 +3451,8 @@ impl ExtensionManager { .or_else(|| relay_config.callback_url.clone()) .unwrap_or_else(|| { let host = std::env::var("GATEWAY_HOST").unwrap_or_else(|_| "127.0.0.1".into()); - let port = std::env::var("GATEWAY_PORT").unwrap_or_else(|_| "3001".into()); + let port = std::env::var("GATEWAY_PORT") + .unwrap_or_else(|_| crate::config::DEFAULT_GATEWAY_PORT.to_string()); format!("http://{}:{}", host, port) }); @@ -3843,11 +3851,12 @@ impl ExtensionManager { secret_name, name ))); } - if secret_value.trim().is_empty() { + let 
trimmed_value = secret_value.trim(); + if trimmed_value.is_empty() { continue; } let params = - CreateSecretParams::new(secret_name, secret_value).with_provider(name.to_string()); + CreateSecretParams::new(secret_name, trimmed_value).with_provider(name.to_string()); self.secrets .create(&self.user_id, params) .await diff --git a/src/main.rs b/src/main.rs index a7d95bec60..0b4695305a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -92,6 +92,10 @@ async fn async_main() -> anyhow::Result<()> { return ironclaw::cli::run_skills_command(skills_cmd.clone(), cli.config.as_deref()) .await; } + Some(Command::Logs(logs_cmd)) => { + init_cli_tracing(); + return ironclaw::cli::run_logs_command(logs_cmd.clone(), cli.config.as_deref()).await; + } Some(Command::Doctor) => { init_cli_tracing(); return ironclaw::cli::run_doctor_command().await; diff --git a/src/settings.rs b/src/settings.rs index 482291b6af..29bfbae169 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -220,7 +220,7 @@ pub struct TunnelSettings { } /// Channel-specific settings. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChannelSettings { /// Whether HTTP webhook channel is enabled. #[serde(default)] @@ -234,6 +234,30 @@ pub struct ChannelSettings { #[serde(default)] pub http_host: Option, + /// Whether the web gateway is enabled. + #[serde(default = "default_true")] + pub gateway_enabled: bool, + + /// Web gateway listen host. + #[serde(default)] + pub gateway_host: Option, + + /// Web gateway listen port. + #[serde(default)] + pub gateway_port: Option, + + /// Web gateway bearer auth token. Auto-generated at gateway startup if unset. + #[serde(default)] + pub gateway_auth_token: Option, + + /// Web gateway user ID. + #[serde(default)] + pub gateway_user_id: Option, + + /// Whether the CLI channel is enabled. + #[serde(default = "default_true")] + pub cli_enabled: bool, + /// Whether Signal channel is enabled. 
#[serde(default)] pub signal_enabled: bool, @@ -289,6 +313,34 @@ pub struct ChannelSettings { pub wasm_channels_dir: Option, } +impl Default for ChannelSettings { + fn default() -> Self { + Self { + http_enabled: false, + http_port: None, + http_host: None, + gateway_enabled: true, + gateway_host: None, + gateway_port: None, + gateway_auth_token: None, + gateway_user_id: None, + cli_enabled: true, + signal_enabled: false, + signal_http_url: None, + signal_account: None, + signal_allow_from: None, + signal_allow_from_groups: None, + signal_dm_policy: None, + signal_group_policy: None, + signal_group_allow_from: None, + wasm_channel_owner_ids: std::collections::HashMap::new(), + wasm_channels: Vec::new(), + wasm_channels_enabled: true, + wasm_channels_dir: None, + } + } +} + /// Heartbeat configuration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HeartbeatSettings { diff --git a/src/tools/builtin/skill_tools.rs b/src/tools/builtin/skill_tools.rs index a7581ac44e..457f16136a 100644 --- a/src/tools/builtin/skill_tools.rs +++ b/src/tools/builtin/skill_tools.rs @@ -301,7 +301,11 @@ impl Tool for SkillInstallTool { let content = if let Some(raw) = params.get("content").and_then(|v| v.as_str()) { // Direct content provided raw.to_string() - } else if let Some(url) = params.get("url").and_then(|v| v.as_str()) { + } else if let Some(url) = params + .get("url") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + { // Fetch from explicit URL fetch_skill_content(url).await? } else { @@ -1297,4 +1301,23 @@ mod tests { ); } } + + #[test] + fn test_empty_url_param_is_treated_as_absent() { + // LLMs sometimes pass "" for optional parameters instead of omitting them. + // Before the fix, url: "" would match Some("") and attempt to fetch from an + // empty URL (failing with an invalid URL error) instead of falling through to + // the catalog lookup. 
The full execute path cannot be tested here without a + // real catalog and database, so this test verifies the parameter filtering + // behaviour directly. + let params = serde_json::json!({"name": "my-skill", "url": ""}); + let url = params + .get("url") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()); + assert!( + url.is_none(), + "empty url string should be treated as absent" + ); + } } diff --git a/src/tools/mcp/auth.rs b/src/tools/mcp/auth.rs index 70df42ea2b..1926e78db8 100644 --- a/src/tools/mcp/auth.rs +++ b/src/tools/mcp/auth.rs @@ -24,7 +24,7 @@ use crate::tools::mcp::config::McpServerConfig; /// Per-request timeouts can override the default via `.timeout()` on /// the request builder. fn oauth_http_client() -> Result<&'static reqwest::Client, AuthError> { - static CLIENT: std::sync::OnceLock> = + static CLIENT: std::sync::OnceLock> = std::sync::OnceLock::new(); CLIENT .get_or_init(|| { @@ -32,10 +32,10 @@ fn oauth_http_client() -> Result<&'static reqwest::Client, AuthError> { .timeout(Duration::from_secs(30)) .redirect(reqwest::redirect::Policy::none()) .build() - .map_err(|e| e.to_string()) + .map_err(|e| AuthError::Http(e.to_string())) }) .as_ref() - .map_err(|e| AuthError::Http(e.clone())) + .map_err(Clone::clone) } /// Log a debug message when a discovery/auth response is a redirect. @@ -57,7 +57,7 @@ fn log_redirect_if_applicable(url: &str, response: &reqwest::Response) { } /// OAuth authorization error. 
-#[derive(Debug, thiserror::Error)] +#[derive(Debug, Clone, thiserror::Error)] pub enum AuthError { #[error("Server does not support OAuth authorization")] NotSupported, @@ -443,6 +443,11 @@ async fn fetch_resource_metadata(url: &str) -> Result Result { validate_url_safe(server_url).await?; @@ -459,9 +464,13 @@ async fn discover_via_401(server_url: &str) -> Result Result assert_eq!(message, "builder failed"), // safety: test assertion in #[cfg(test)] module; not production panic path + other => panic!("expected AuthError::Http variant, got {other:?}"), + } + } + // --- New tests for well-known URI construction --- #[test] diff --git a/src/tools/mcp/client.rs b/src/tools/mcp/client.rs index 286ee63c12..c299ac49c9 100644 --- a/src/tools/mcp/client.rs +++ b/src/tools/mcp/client.rs @@ -275,7 +275,10 @@ impl McpClient { .keys() .any(|k| k.eq_ignore_ascii_case("authorization")); if !has_custom_auth && let Some(token) = self.get_access_token().await? { - headers.insert("Authorization".to_string(), format!("Bearer {}", token)); + let trimmed = token.trim(); + if !trimmed.is_empty() { + headers.insert("Authorization".to_string(), format!("Bearer {}", trimmed)); + } } if let Some(ref session_manager) = self.session_manager && let Some(session_id) = session_manager.get_session_id(&self.server_name).await @@ -302,7 +305,12 @@ impl McpClient { match result { Ok(response) => return Ok(response), Err(ToolError::ExternalService(ref msg)) - if msg.contains("401") || msg.contains("Unauthorized") => + if msg.contains("401") + || msg.contains("Unauthorized") + || (msg.contains("400") && { + let lower = msg.to_ascii_lowercase(); + lower.contains("authorization") || lower.contains("authenticate") + }) => { if attempt == 0 && let Some(ref secrets) = self.secrets @@ -1113,4 +1121,136 @@ mod tests { let approval = wrapper.requires_approval(&serde_json::json!({})); assert_eq!(approval, ApprovalRequirement::Never); } + + // Regression test: empty/whitespace-only tokens must not produce a + 
// malformed `Authorization: Bearer ` header (GitHub MCP returns 400 + // "Authorization header is badly formatted" in this case). + #[tokio::test] + async fn test_build_headers_skips_empty_token() { + use crate::secrets::{CreateSecretParams, DecryptedSecret, Secret, SecretError, SecretRef}; + use uuid::Uuid; + + // In-memory secrets store that returns a whitespace-only string for the token. + struct EmptyTokenStore; + #[async_trait] + impl crate::secrets::SecretsStore for EmptyTokenStore { + async fn create( + &self, + _user_id: &str, + _params: CreateSecretParams, + ) -> Result { + unimplemented!() + } + async fn get(&self, _user_id: &str, _name: &str) -> Result { + unimplemented!() + } + async fn get_decrypted( + &self, + _user_id: &str, + _name: &str, + ) -> Result { + DecryptedSecret::from_bytes(b" ".to_vec()) + } + async fn exists(&self, _user_id: &str, _name: &str) -> Result { + Ok(true) + } + async fn delete(&self, _user_id: &str, _name: &str) -> Result { + Ok(true) + } + async fn list(&self, _user_id: &str) -> Result, SecretError> { + Ok(Vec::new()) + } + async fn record_usage(&self, _secret_id: Uuid) -> Result<(), SecretError> { + Ok(()) + } + async fn is_accessible( + &self, + _user_id: &str, + _secret_name: &str, + _allowed_secrets: &[String], + ) -> Result { + Ok(true) + } + } + + let config = McpServerConfig::new("github", "https://api.githubcopilot.com/mcp/"); + let session_manager = Arc::new(McpSessionManager::new()); + let secrets: Arc = + Arc::new(EmptyTokenStore); + + let client = McpClient::new_authenticated(config, session_manager, secrets, "test-user"); + + let headers = client.build_request_headers().await.unwrap(); // safety: test + assert!( + // safety: test + !headers.contains_key("Authorization"), + "Empty/whitespace token must not produce an Authorization header, got: {:?}", + headers.get("Authorization") + ); + } + + // Regression test: tokens with leading/trailing whitespace must be trimmed + // before being used in the Authorization 
header. + #[tokio::test] + async fn test_build_headers_trims_token() { + use crate::secrets::{CreateSecretParams, DecryptedSecret, Secret, SecretError, SecretRef}; + use uuid::Uuid; + + struct PaddedTokenStore; + #[async_trait] + impl crate::secrets::SecretsStore for PaddedTokenStore { + async fn create( + &self, + _user_id: &str, + _params: CreateSecretParams, + ) -> Result { + unimplemented!() + } + async fn get(&self, _user_id: &str, _name: &str) -> Result { + unimplemented!() + } + async fn get_decrypted( + &self, + _user_id: &str, + _name: &str, + ) -> Result { + DecryptedSecret::from_bytes(b" gho_abc123 \n".to_vec()) + } + async fn exists(&self, _user_id: &str, _name: &str) -> Result { + Ok(true) + } + async fn delete(&self, _user_id: &str, _name: &str) -> Result { + Ok(true) + } + async fn list(&self, _user_id: &str) -> Result, SecretError> { + Ok(Vec::new()) + } + async fn record_usage(&self, _secret_id: Uuid) -> Result<(), SecretError> { + Ok(()) + } + async fn is_accessible( + &self, + _user_id: &str, + _secret_name: &str, + _allowed_secrets: &[String], + ) -> Result { + Ok(true) + } + } + + let config = McpServerConfig::new("github", "https://api.githubcopilot.com/mcp/"); + let session_manager = Arc::new(McpSessionManager::new()); + let secrets: Arc = + Arc::new(PaddedTokenStore); + + let client = McpClient::new_authenticated(config, session_manager, secrets, "test-user"); + + let headers = client.build_request_headers().await.unwrap(); // safety: test + assert_eq!( + // safety: test + headers.get("Authorization").unwrap(), // safety: test + "Bearer gho_abc123", + "Token must be trimmed before use in Authorization header" + ); + } } diff --git a/tests/e2e/mock_llm.py b/tests/e2e/mock_llm.py index 0fa0ce9f2d..175accf520 100644 --- a/tests/e2e/mock_llm.py +++ b/tests/e2e/mock_llm.py @@ -225,6 +225,128 @@ async def models(_request: web.Request) -> web.Response: }) +# ── Mock MCP Server ────────────────────────────────────────────────────────── +# +# 
Simulates an MCP server that requires OAuth. Unauthenticated requests get +# 401 + WWW-Authenticate (standard MCP flow) or 400 "Authorization header is +# badly formatted" (GitHub-style). Authenticated requests return valid +# JSON-RPC responses for initialize and tools/list. + + +async def mcp_endpoint(request: web.Request) -> web.Response: + """Handle POST /mcp — JSON-RPC MCP endpoint requiring Bearer auth.""" + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer ") or len(auth.split(" ", 1)[1].strip()) == 0: + # Return 401 with WWW-Authenticate header for OAuth discovery + resource_meta_url = f"http://127.0.0.1:{request.app['port']}/.well-known/oauth-protected-resource" + return web.Response( + status=401, + headers={"WWW-Authenticate": f'Bearer resource_metadata="{resource_meta_url}"'}, + text="Unauthorized", + ) + return await _mcp_handle_authed(request) + + +async def mcp_endpoint_400(request: web.Request) -> web.Response: + """Handle POST /mcp-400 — MCP endpoint that returns 400 (GitHub-style). + + Simulates GitHub's MCP server which returns 400 "Authorization header + is badly formatted" instead of 401 when auth is missing or invalid. 
+ """ + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer ") or len(auth.split(" ", 1)[1].strip()) == 0: + return web.Response( + status=400, + text="bad request: Authorization header is badly formatted", + ) + return await _mcp_handle_authed(request) + + +async def _mcp_handle_authed(request: web.Request) -> web.Response: + """Handle an authenticated MCP JSON-RPC request.""" + body = await request.json() + method = body.get("method", "") + req_id = body.get("id") + + if method == "initialize": + return web.json_response({ + "jsonrpc": "2.0", "id": req_id, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "mock-mcp", "version": "1.0.0"}, + }, + }) + if method == "notifications/initialized": + return web.json_response({"jsonrpc": "2.0", "id": req_id, "result": {}}) + if method == "tools/list": + return web.json_response({ + "jsonrpc": "2.0", "id": req_id, + "result": {"tools": [{ + "name": "mock_search", + "description": "A mock search tool for testing", + "inputSchema": {"type": "object", "properties": { + "query": {"type": "string"}, + }}, + }]}, + }) + return web.json_response({"jsonrpc": "2.0", "id": req_id, "error": { + "code": -32601, "message": f"Method not found: {method}", + }}) + + +async def mcp_protected_resource(request: web.Request) -> web.Response: + """GET /.well-known/oauth-protected-resource[/{path}] — RFC 9728 discovery. + + Production code appends the MCP server path after the well-known suffix + (e.g. /.well-known/oauth-protected-resource/mcp-400), so this handler + accepts an optional tail and returns a resource matching the request. 
+ """ + port = request.app["port"] + tail = request.match_info.get("tail", "mcp") + return web.json_response({ + "resource": f"http://127.0.0.1:{port}/{tail}", + "authorization_servers": [f"http://127.0.0.1:{port}"], + }) + + +async def mcp_auth_server_metadata(request: web.Request) -> web.Response: + """GET /.well-known/oauth-authorization-server[/{path}] — OAuth metadata.""" + port = request.app["port"] + base = f"http://127.0.0.1:{port}" + return web.json_response({ + "issuer": base, + "authorization_endpoint": f"{base}/oauth/authorize", + "token_endpoint": f"{base}/oauth/token", + "registration_endpoint": f"{base}/oauth/register", + "scopes_supported": ["read", "write"], + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + }) + + +async def mcp_oauth_register(request: web.Request) -> web.Response: + """POST /oauth/register — Dynamic Client Registration.""" + body = await request.json() + return web.json_response({ + "client_id": "mock-mcp-client-id", + "client_name": body.get("client_name", "IronClaw"), + "redirect_uris": body.get("redirect_uris", []), + }) + + +async def mcp_oauth_token(request: web.Request) -> web.Response: + """POST /oauth/token — Token endpoint for MCP OAuth.""" + data = await request.post() + code = data.get("code", "") + return web.json_response({ + "access_token": f"mcp-token-{code}", + "token_type": "Bearer", + "expires_in": 3600, + }) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=0) @@ -236,6 +358,15 @@ def main(): app.router.add_get("/v1/models", models) app.router.add_get("/models", models) app.router.add_post("/oauth/exchange", oauth_exchange) + # Mock MCP server endpoints + app.router.add_post("/mcp", mcp_endpoint) + app.router.add_post("/mcp-400", mcp_endpoint_400) + app.router.add_get("/.well-known/oauth-protected-resource", mcp_protected_resource) + 
app.router.add_get("/.well-known/oauth-protected-resource/{tail:.*}", mcp_protected_resource) + app.router.add_get("/.well-known/oauth-authorization-server", mcp_auth_server_metadata) + app.router.add_get("/.well-known/oauth-authorization-server/{tail:.*}", mcp_auth_server_metadata) + app.router.add_post("/oauth/register", mcp_oauth_register) + app.router.add_post("/oauth/token", mcp_oauth_token) async def start(): runner = web.AppRunner(app) @@ -243,6 +374,7 @@ async def start(): site = web.TCPSite(runner, "127.0.0.1", args.port) await site.start() port = site._server.sockets[0].getsockname()[1] + app["port"] = port # used by MCP handlers print(f"MOCK_LLM_PORT={port}", flush=True) await asyncio.Event().wait() diff --git a/tests/e2e/scenarios/test_chat.py b/tests/e2e/scenarios/test_chat.py index 24b3d98d7a..440eb18efd 100644 --- a/tests/e2e/scenarios/test_chat.py +++ b/tests/e2e/scenarios/test_chat.py @@ -74,3 +74,44 @@ async def test_empty_message_not_sent(page): await page.wait_for_timeout(2000) final_count = await page.locator(f"{SEL['message_user']}, {SEL['message_assistant']}").count() assert final_count == initial_count, "Empty message should not create new messages" + + +async def test_copy_from_chat_forces_plain_text(page): + """Copying selected chat text should populate plain text clipboard data only.""" + await page.evaluate("addMessage('assistant', 'Copy me into Sheets')") + + copied = await page.evaluate( + """ + () => { + const content = Array.from(document.querySelectorAll('#chat-messages .message.assistant .message-content')) + .find((el) => (el.textContent || '').includes('Copy me into Sheets')); + if (!content) return {ok: false, reason: 'no content'}; + const range = document.createRange(); + range.selectNodeContents(content); + const sel = window.getSelection(); + sel.removeAllRanges(); + sel.addRange(range); + + const store = {}; + const evt = new Event('copy', { bubbles: true, cancelable: true }); + evt.clipboardData = { + clearData: () => { 
Object.keys(store).forEach((k) => delete store[k]); }, + setData: (t, v) => { store[t] = v; }, + getData: (t) => store[t] || '', + }; + + content.dispatchEvent(evt); + return { + ok: true, + defaultPrevented: evt.defaultPrevented, + text: store['text/plain'] || '', + html: store['text/html'] || '', + }; + } + """ + ) + + assert copied["ok"], copied.get("reason", "copy setup failed") + assert copied["defaultPrevented"] is True + assert "Copy me into Sheets" in copied["text"] + assert copied["html"] == "" diff --git a/tests/e2e/scenarios/test_mcp_auth_flow.py b/tests/e2e/scenarios/test_mcp_auth_flow.py new file mode 100644 index 0000000000..7de2bbe689 --- /dev/null +++ b/tests/e2e/scenarios/test_mcp_auth_flow.py @@ -0,0 +1,355 @@ +"""MCP server auth flow E2E tests. + +Tests the full MCP server lifecycle: install MCP server (pointing at mock) -> +activate triggers auth (401/400 -> AuthRequired -> OAuth URL) -> OAuth callback +completes -> auth mode cleared (next message triggers LLM turn) -> MCP tools +available. + +Regression coverage for: + - 400 "Authorization header is badly formatted" treated as auth-required + - OAuth discovery via 401 + WWW-Authenticate header + - clear_auth_mode after OAuth callback (user message not swallowed) + - Token trimming (whitespace/newline in stored tokens) + +The mock_llm.py serves a mock MCP server at /mcp with full OAuth discovery +endpoints (.well-known/oauth-protected-resource, DCR, token exchange). 
+""" + +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest + +from helpers import SEL, api_get, api_post + + +def _extract_state(auth_url: str) -> str: + """Extract the CSRF state parameter from an OAuth authorization URL.""" + parsed = urlparse(auth_url) + qs = parse_qs(parsed.query) + assert "state" in qs, f"auth_url should contain state param: {auth_url}" + return qs["state"][0] + + +async def _get_extension(base_url, name): + """Get a specific extension from the extensions list, or None.""" + r = await api_get(base_url, "/api/extensions") + for ext in r.json().get("extensions", []): + if ext["name"] == name: + return ext + return None + + +async def _ensure_removed(base_url, name): + """Remove extension if already installed.""" + ext = await _get_extension(base_url, name) + if ext: + await api_post(base_url, f"/api/extensions/{name}/remove", timeout=30) + + +# ── Section A: Install MCP Server ──────────────────────────────────────── + + +async def test_mcp_install(ironclaw_server, mock_llm_server): + """Install a mock MCP server pointing at mock_llm.py's /mcp endpoint.""" + await _ensure_removed(ironclaw_server, "mock-mcp") + + mcp_url = f"{mock_llm_server}/mcp" + r = await api_post( + ironclaw_server, + "/api/extensions/install", + json={"name": "mock-mcp", "url": mcp_url, "kind": "mcp_server"}, + timeout=30, + ) + assert r.status_code == 200 + data = r.json() + assert data.get("success") is True, f"Install failed: {data}" + + ext = await _get_extension(ironclaw_server, "mock-mcp") + assert ext is not None, "mock-mcp should appear in extensions list" + assert ext["kind"] == "mcp_server" + + +# ── Section B: Activate Triggers Auth ──────────────────────────────────── + + +async def test_mcp_activate_triggers_auth(ironclaw_server): + """Activating an unauthenticated MCP server triggers the OAuth flow. + + The mock MCP returns 401 with WWW-Authenticate when no Bearer token + is present. 
The activate handler should detect this as auth-required + and return an auth_url. + """ + ext = await _get_extension(ironclaw_server, "mock-mcp") + if ext is None: + pytest.skip("mock-mcp not installed") + + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp/activate", + timeout=30, + ) + assert r.status_code == 200 + data = r.json() + + # Activation should fail with an auth_url (OAuth needed) + # OR it should return awaiting_token (manual token prompt) + auth_url = data.get("auth_url") + awaiting_token = data.get("awaiting_token") + assert auth_url is not None or awaiting_token, ( + f"Activate should require auth, got: {data}" + ) + + +# ── Section C: OAuth Round-Trip ────────────────────────────────────────── + + +async def test_mcp_oauth_callback(ironclaw_server): + """Complete the OAuth flow via setup + callback for the MCP server.""" + ext = await _get_extension(ironclaw_server, "mock-mcp") + if ext is None: + pytest.skip("mock-mcp not installed") + + # Configure with empty secrets to trigger OAuth + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp/setup", + json={"secrets": {}}, + timeout=30, + ) + assert r.status_code == 200 + data = r.json() + + # If no auth_url, try activate to trigger it + auth_url = data.get("auth_url") + if auth_url is None: + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp/activate", + timeout=30, + ) + data = r.json() + auth_url = data.get("auth_url") + + if auth_url is None: + # Server might have been auto-authenticated via DCR; check if active + ext = await _get_extension(ironclaw_server, "mock-mcp") + if ext and ext.get("authenticated"): + return # Already authenticated, skip callback test + pytest.skip("Could not obtain auth_url for mock-mcp") + + csrf_state = _extract_state(auth_url) + + # Hit the OAuth callback endpoint + async with httpx.AsyncClient() as client: + r = await client.get( + f"{ironclaw_server}/oauth/callback", + params={"code": "mock_mcp_code", "state": 
csrf_state}, + timeout=30, + follow_redirects=True, + ) + assert r.status_code == 200, f"Callback returned {r.status_code}: {r.text[:300]}" + body = r.text.lower() + assert "connected" in body or "success" in body, ( + f"Callback should indicate success: {r.text[:500]}" + ) + + +async def test_mcp_authenticated_after_oauth(ironclaw_server): + """After OAuth callback, MCP server shows authenticated=True.""" + ext = await _get_extension(ironclaw_server, "mock-mcp") + if ext is None: + pytest.skip("mock-mcp not installed") + assert ext["authenticated"] is True, ( + f"mock-mcp should be authenticated after OAuth: {ext}" + ) + + +async def test_mcp_tools_registered(ironclaw_server): + """After authentication, MCP tools appear in the extension.""" + ext = await _get_extension(ironclaw_server, "mock-mcp") + if ext is None: + pytest.skip("mock-mcp not installed") + tools = ext.get("tools", []) + assert len(tools) > 0, f"mock-mcp should have tools after auth: {ext}" + # The mock MCP serves a tool named "mock_search", prefixed with server name + tool_names = [t for t in tools if "mock_search" in t] + assert len(tool_names) > 0, f"Expected mock_search tool, got: {tools}" + + +# ── Section D: Auth Mode Cleared — LLM Turn Fires ─────────────────────── + + +async def test_mcp_auth_mode_cleared_llm_turn_fires(ironclaw_server, page): + """After OAuth completes, the next user message triggers an LLM turn. + + Regression test: previously, pending_auth was not cleared by the OAuth + callback handler, so the next user message was consumed as a token and + the LLM turn never fired. 
+ """ + chat_input = page.locator(SEL["chat_input"]) + await chat_input.wait_for(state="visible", timeout=5000) + + assistant_sel = SEL["message_assistant"] + before_count = await page.locator(assistant_sel).count() + + # Send a normal message — should trigger LLM, not be swallowed by auth + await chat_input.fill("hello") + await chat_input.press("Enter") + + # Wait for assistant response + expected = before_count + 1 + await page.wait_for_function( + """({ assistantSelector, expectedCount }) => { + const messages = document.querySelectorAll(assistantSelector); + return messages.length >= expectedCount; + }""", + arg={"assistantSelector": assistant_sel, "expectedCount": expected}, + timeout=15000, + ) + + text = await page.locator(assistant_sel).last.inner_text() + assert len(text.strip()) > 0, "Assistant should have responded" + + +# ── Section E: GitHub-style 400 Error ───────────────────────────────────── + + +async def test_mcp_400_activate_triggers_auth(ironclaw_server, mock_llm_server): + """MCP server returning 400 "Authorization header is badly formatted" + is treated as auth-required (regression for GitHub MCP). + + Previously, only 401 triggered the auth flow. GitHub's MCP returns 400 + with "Authorization header is badly formatted" instead. 
+ """ + await _ensure_removed(ironclaw_server, "mock-mcp-400") + + mcp_url = f"{mock_llm_server}/mcp-400" + r = await api_post( + ironclaw_server, + "/api/extensions/install", + json={"name": "mock-mcp-400", "url": mcp_url, "kind": "mcp_server"}, + timeout=30, + ) + assert r.status_code == 200 + assert r.json().get("success") is True, f"Install failed: {r.json()}" + + # Activate should detect 400 + "authorization" as auth-required + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp-400/activate", + timeout=30, + ) + assert r.status_code == 200, f"Activate returned {r.status_code}: {r.text[:300]}" + data = r.json() + + # The 400 should be treated as auth-required, returning an auth_url + # or awaiting_token — not a raw "400 Bad Request" activation error. + auth_url = data.get("auth_url") + awaiting_token = data.get("awaiting_token") + assert auth_url is not None or awaiting_token, ( + f"400 auth error should trigger auth flow (auth_url or awaiting_token), got: {data}" + ) + + +async def test_mcp_400_oauth_discovery_returns_auth_url(ironclaw_server): + """OAuth discovery succeeds for the 400-variant via RFC 9728 (strategy 2). + + Strategy 1 (discover_via_401) fails because /mcp-400 returns 400 without + a WWW-Authenticate header. Strategy 2 queries + /.well-known/oauth-protected-resource/mcp-400 (path-suffixed) and must + find the mock's wildcard route. Without that route, discovery fails + entirely and only awaiting_token (manual) is returned — no auth_url. + + This test would have failed before the wildcard .well-known routes were + added to mock_llm.py. 
+ """ + ext = await _get_extension(ironclaw_server, "mock-mcp-400") + if ext is None: + pytest.skip("mock-mcp-400 not installed") + + # Re-activate to get a fresh auth response + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp-400/activate", + timeout=30, + ) + assert r.status_code == 200, f"Activate returned {r.status_code}: {r.text[:300]}" + data = r.json() + + auth_url = data.get("auth_url") + assert auth_url is not None, ( + f"OAuth discovery must produce an auth_url (not just awaiting_token). " + f"Strategy 2 (RFC 9728) likely failed — check .well-known wildcard routes. " + f"Got: {data}" + ) + + +async def test_mcp_400_full_oauth_roundtrip(ironclaw_server): + """Complete OAuth round-trip for the 400-variant MCP server. + + Exercises the full path: activate → 400 detected as auth-required → + OAuth discovery via strategy 2 (path-suffixed .well-known) → DCR → + auth_url returned → callback completes token exchange → extension + authenticated with tools. + + Without the wildcard .well-known routes, OAuth discovery fails and + no auth_url is produced, so this test would fail at the csrf_state + extraction step. 
+ """ + ext = await _get_extension(ironclaw_server, "mock-mcp-400") + if ext is None: + pytest.skip("mock-mcp-400 not installed") + + # Get a fresh auth_url via activate + r = await api_post( + ironclaw_server, + "/api/extensions/mock-mcp-400/activate", + timeout=30, + ) + data = r.json() + auth_url = data.get("auth_url") + if auth_url is None: + pytest.skip("No auth_url from activate (discovery may not have succeeded)") + + csrf_state = _extract_state(auth_url) + + # Complete OAuth callback + async with httpx.AsyncClient() as client: + r = await client.get( + f"{ironclaw_server}/oauth/callback", + params={"code": "mock_400_code", "state": csrf_state}, + timeout=30, + follow_redirects=True, + ) + assert r.status_code == 200, f"Callback returned {r.status_code}: {r.text[:300]}" + body = r.text.lower() + assert "connected" in body or "success" in body, ( + f"400-variant OAuth callback should succeed: {r.text[:500]}" + ) + + # Verify authenticated + tools loaded + ext = await _get_extension(ironclaw_server, "mock-mcp-400") + assert ext is not None, "mock-mcp-400 should still be installed" + assert ext["authenticated"] is True, ( + f"mock-mcp-400 should be authenticated after OAuth: {ext}" + ) + tools = ext.get("tools", []) + assert len(tools) > 0, f"mock-mcp-400 should have tools after auth: {ext}" + + +async def test_mcp_400_cleanup(ironclaw_server): + """Clean up the 400-variant MCP server.""" + await _ensure_removed(ironclaw_server, "mock-mcp-400") + ext = await _get_extension(ironclaw_server, "mock-mcp-400") + assert ext is None, "mock-mcp-400 should be removed" + + +# ── Section F: Cleanup ─────────────────────────────────────────────────── + + +async def test_mcp_cleanup(ironclaw_server): + """Remove mock-mcp (cleanup for other test files).""" + await _ensure_removed(ironclaw_server, "mock-mcp") + ext = await _get_extension(ironclaw_server, "mock-mcp") + assert ext is None, "mock-mcp should be removed"