diff --git a/src/agent/thread_ops.rs b/src/agent/thread_ops.rs index e52d515c9f..721f4dd268 100644 --- a/src/agent/thread_ops.rs +++ b/src/agent/thread_ops.rs @@ -371,70 +371,73 @@ impl Agent { match thread_state { ThreadState::Processing => { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - // Re-check state under lock — the turn may have completed - // between the snapshot read and this mutable lock acquisition. - if thread.state == ThreadState::Processing { - // Reject messages with attachments — the queue stores - // text only, so attachments would be silently dropped. - if !message.attachments.is_empty() { - return Ok(SubmissionResult::error( - "Cannot queue messages with attachments while a turn is processing. \ - Please resend after the current turn completes.", - )); - } + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + // Re-check state under lock — the turn may have completed + // between the snapshot read and this mutable lock acquisition. + if thread.state == ThreadState::Processing { + // Reject messages with attachments — the queue stores + // text only, so attachments would be silently dropped. + if !message.attachments.is_empty() { + return Ok(SubmissionResult::error( + "Cannot queue messages with attachments while a turn is processing. \ + Please resend after the current turn completes.", + )); + } - // Run the same safety checks that the normal path applies - // (validation, policy, secret scan) so that blocked content - // is never stored in pending_messages or serialized. - let validation = self.safety().validate_input(content); - if !validation.is_valid { - let details = validation - .errors + // Run the same safety checks that the normal path applies + // (validation, policy, secret scan) so that blocked content + // is never stored in pending_messages or serialized. + let validation = self.safety().validate_input(content); + if !validation.is_valid { + let details = validation + .errors + .iter() + .map(|e| format!("{}: {}", e.field, e.message)) + .collect::>() + .join("; "); + return Ok(SubmissionResult::error(format!( + "Input rejected by safety validation: {details}", + ))); + } + let violations = self.safety().check_policy(content); + if violations .iter() - .map(|e| format!("{}: {}", e.field, e.message)) - .collect::>() - .join("; "); - return Ok(SubmissionResult::error(format!( - "Input rejected by safety validation: {details}", - ))); - } - let violations = self.safety().check_policy(content); - if violations - .iter() - .any(|rule| rule.action == ironclaw_safety::PolicyAction::Block) - { - return Ok(SubmissionResult::error("Input rejected by safety policy.")); - } - if let Some(warning) = self.safety().scan_inbound_for_secrets(content) { - tracing::warn!( - user = %message.user_id, - channel = %message.channel, - "Queued message blocked: contains leaked secret" - ); - return Ok(SubmissionResult::error(warning)); - } + .any(|rule| rule.action == ironclaw_safety::PolicyAction::Block) + { + return Ok(SubmissionResult::error("Input rejected by safety policy.")); + } + if let Some(warning) = self.safety().scan_inbound_for_secrets(content) { + tracing::warn!( + user = %message.user_id, + channel = %message.channel, + "Queued message blocked: contains leaked secret" + ); + return Ok(SubmissionResult::error(warning)); + } - if !thread.queue_message(content.to_string()) { - return Ok(SubmissionResult::error(format!( - "Message queue full ({MAX_PENDING_MESSAGES}). Wait for the current turn to complete.", - ))); + if !thread.queue_message(content.to_string()) { + return Ok(SubmissionResult::error(format!( + "Message queue full ({MAX_PENDING_MESSAGES}). Wait for the current turn to complete.", + ))); + } + // Return `Ok` (not `Response`) so the drain loop in + // agent_loop.rs breaks — `Ok` signals a control + // acknowledgment, not a completed LLM turn. + return Ok(SubmissionResult::Ok { + message: Some( + "Message queued — will be processed after the current turn.".into(), + ), + }); } - // Return `Ok` (not `Response`) so the drain loop in - // agent_loop.rs breaks — `Ok` signals a control - // acknowledgment, not a completed LLM turn. - return Ok(SubmissionResult::Ok { - message: Some( - "Message queued — will be processed after the current turn.".into(), - ), - }); + // State changed (turn completed) — fall through to process normally. + // NOTE: `sess` (the Mutex guard) is dropped at the end of + // this `Processing` match arm, releasing the session lock + // before the rest of process_user_input runs. No deadlock. + } + None => { + return Ok(SubmissionResult::error("Thread no longer exists.")); } - // State changed (turn completed) — fall through to process normally. - // NOTE: `sess` (the Mutex guard) is dropped at the end of - // this `Processing` match arm, releasing the session lock - // before the rest of process_user_input runs. No deadlock. - } else { - return Ok(SubmissionResult::error("Thread no longer exists.")); } } ThreadState::AwaitingApproval => { @@ -1159,7 +1162,8 @@ impl Agent { approved: bool, always: bool, ) -> Result { - // Get pending approval for this thread + // Take + verify under a single lock to prevent TOCTOU races where a + // concurrent operation could modify the thread between take and restore. let pending = { let mut sess = session.lock().await; let thread = sess @@ -1177,33 +1181,30 @@ impl Agent { return Ok(SubmissionResult::ok_with_message("")); } - thread.take_pending_approval() - }; - - let pending = match pending { - Some(p) => p, - None => { - tracing::debug!( - %thread_id, - "Ignoring stale approval: no pending approval found" - ); - return Ok(SubmissionResult::ok_with_message("")); - } - }; + let pending = match thread.take_pending_approval() { + Some(p) => p, + None => { + tracing::debug!( + %thread_id, + "Ignoring stale approval: no pending approval found" + ); + return Ok(SubmissionResult::ok_with_message("")); + } + }; - // Verify request ID if provided - if let Some(req_id) = request_id - && req_id != pending.request_id - { - // Put it back and return error - let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { + // Verify request ID while still holding the lock so the pending + // approval is atomically restored on mismatch. + if let Some(req_id) = request_id + && req_id != pending.request_id + { thread.await_approval(pending); + return Ok(SubmissionResult::error( + "Request ID mismatch. Use the correct request ID.", + )); } - return Ok(SubmissionResult::error( - "Request ID mismatch. Use the correct request ID.", - )); - } + + pending + }; if approved { // If always, add to auto-approved set and persist to settings. @@ -1266,8 +1267,13 @@ impl Agent { // Reset thread state to processing { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.state = ThreadState::Processing; + match sess.threads.get_mut(&thread_id) { + Some(thread) => thread.state = ThreadState::Processing, + None => { + return Err(Error::from(crate::error::JobError::NotFound { + id: thread_id, + })); + } } } @@ -1358,15 +1364,26 @@ impl Agent { // Record sanitized result in thread { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) - && let Some(turn) = thread.last_turn_mut() - { - if is_tool_error { - turn.record_tool_error_for(&pending.tool_call_id, result_content.clone()); - } else { - turn.record_tool_result_for( - &pending.tool_call_id, - serde_json::json!(result_content), + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + if let Some(turn) = thread.last_turn_mut() { + if is_tool_error { + turn.record_tool_error_for( + &pending.tool_call_id, + result_content.clone(), + ); + } else { + turn.record_tool_result_for( + &pending.tool_call_id, + serde_json::json!(result_content), + ); + } + } + } + None => { + tracing::debug!( + %thread_id, + "Thread disappeared before tool result could be recorded" ); } } @@ -1606,15 +1623,24 @@ impl Agent { // Record sanitized result in thread { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) - && let Some(turn) = thread.last_turn_mut() - { - if is_deferred_error { - turn.record_tool_error_for(&tc.id, deferred_content.clone()); - } else { - turn.record_tool_result_for( - &tc.id, - serde_json::json!(deferred_content), + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + if let Some(turn) = thread.last_turn_mut() { + if is_deferred_error { + turn.record_tool_error_for(&tc.id, deferred_content.clone()); + } else { + turn.record_tool_result_for( + &tc.id, + serde_json::json!(deferred_content), + ); + } + } + } + None => { + tracing::debug!( + %thread_id, + tool = %tc.name, + "Thread disappeared before deferred tool result could be recorded" ); } } @@ -1665,8 +1691,13 @@ impl Agent { { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.await_approval(new_pending); + match sess.threads.get_mut(&thread_id) { + Some(thread) => thread.await_approval(new_pending), + None => { + return Err(Error::from(crate::error::JobError::NotFound { + id: thread_id, + })); + } } } @@ -1868,17 +1899,24 @@ impl Agent { ); { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.clear_pending_approval(); - thread.complete_turn(&rejection); - // User message already persisted at turn start; save rejection response - self.persist_assistant_response( - thread_id, - &message.channel, - &message.user_id, - &rejection, - ) - .await; + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + thread.clear_pending_approval(); + thread.complete_turn(&rejection); + // User message already persisted at turn start; save rejection response + self.persist_assistant_response( + thread_id, + &message.channel, + &message.user_id, + &rejection, + ) + .await; + } + None => { + return Err(Error::from(crate::error::JobError::NotFound { + id: thread_id, + })); + } } } @@ -1911,17 +1949,25 @@ impl Agent { ) { { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.enter_auth_mode(ext_name.clone()); - thread.complete_turn(&instructions); - // User message already persisted at turn start; save auth instructions - self.persist_assistant_response( - thread_id, - &message.channel, - &message.user_id, - &instructions, - ) - .await; + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + thread.enter_auth_mode(ext_name.clone()); + thread.complete_turn(&instructions); + // User message already persisted at turn start; save auth instructions + self.persist_assistant_response( + thread_id, + &message.channel, + &message.user_id, + &instructions, + ) + .await; + } + None => { + tracing::debug!( + %thread_id, + "Thread disappeared before auth intercept could be applied" + ); + } } } emit_auth_required_status( @@ -1978,8 +2024,14 @@ impl Agent { // Clear auth mode regardless of outcome { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.pending_auth = None; + match sess.threads.get_mut(&thread_id) { + Some(thread) => thread.pending_auth = None, + None => { + tracing::debug!( + %thread_id, + "Thread disappeared before auth mode could be cleared" + ); + } } } @@ -2031,8 +2083,16 @@ impl Agent { Ok(result) => { { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.enter_auth_mode(pending.extension_name.clone()); + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + thread.enter_auth_mode(pending.extension_name.clone()); + } + None => { + tracing::debug!( + %thread_id, + "Thread disappeared before auth mode could be re-entered" + ); + } } } emit_auth_required_status( @@ -2056,8 +2116,16 @@ impl Agent { ); { let mut sess = session.lock().await; - if let Some(thread) = sess.threads.get_mut(&thread_id) { - thread.enter_auth_mode(pending.extension_name.clone()); + match sess.threads.get_mut(&thread_id) { + Some(thread) => { + thread.enter_auth_mode(pending.extension_name.clone()); + } + None => { + tracing::debug!( + %thread_id, + "Thread disappeared before auth mode could be re-entered on retry" + ); + } } } emit_auth_required_status( @@ -3263,4 +3331,110 @@ mod tests { Ok(None) } } + + /// Regression test for #1486: TOCTOU race in process_approval. + /// + /// After a request_id mismatch the pending approval must be atomically + /// restored so a subsequent approval with the correct ID succeeds. + #[tokio::test] + async fn test_approval_request_id_mismatch_restores_pending() { + use crate::agent::session::{PendingApproval, Session, Thread, ThreadState}; + use uuid::Uuid; + + let session_id = Uuid::new_v4(); + let thread_id = Uuid::new_v4(); + let correct_request_id = Uuid::new_v4(); + let wrong_request_id = Uuid::new_v4(); + + let mut thread = Thread::with_id(thread_id, session_id, None); + let pending = PendingApproval { + request_id: correct_request_id, + tool_name: "echo".to_string(), + parameters: serde_json::json!({"text": "hello"}), + display_parameters: serde_json::json!({"text": "hello"}), + description: "Echo hello".to_string(), + tool_call_id: "call_0".to_string(), + context_messages: vec![], + deferred_tool_calls: vec![], + selected_auth_prompt: None, + user_timezone: None, + allow_always: false, + }; + thread.await_approval(pending); + assert_eq!(thread.state, ThreadState::AwaitingApproval); + + let mut session = Session::new("test-user"); + session.threads.insert(thread_id, thread); + let session = Arc::new(Mutex::new(session)); + + let (agent, _statuses) = make_thread_ops_test_agent().await; + + let message = IncomingMessage::new("test", "test-user", "yes"); + + // Attempt approval with WRONG request ID — should fail but preserve pending + let result = agent + .process_approval( + &message, + session.clone(), + thread_id, + Some(wrong_request_id), + true, + false, + ) + .await; + + assert!(result.is_ok()); + let result = result.unwrap(); + match &result { + SubmissionResult::Error { message } => { + assert!( + message.contains("Request ID mismatch"), + "Expected mismatch error, got: {}", + message + ); + } + other => panic!("Expected Error result, got: {:?}", other), + } + + // Verify pending approval is still present (not lost due to TOCTOU) + let sess = session.lock().await; + let thread = sess.threads.get(&thread_id).unwrap(); + assert_eq!(thread.state, ThreadState::AwaitingApproval); + assert!( + thread.pending_approval.is_some(), + "Pending approval should be restored after request_id mismatch" + ); + assert_eq!( + thread.pending_approval.as_ref().unwrap().request_id, + correct_request_id, + "Restored pending approval should have the original request_id" + ); + } + + /// Regression test for #1487: process_approval on a missing thread should error. + #[tokio::test] + async fn test_approval_on_missing_thread_should_error() { + use crate::agent::session::Session; + use uuid::Uuid; + + let thread_id = Uuid::new_v4(); + + // Session with NO threads + let session = Session::new("test-user"); + let session = Arc::new(Mutex::new(session)); + + let (agent, _statuses) = make_thread_ops_test_agent().await; + + let message = IncomingMessage::new("test", "test-user", "yes"); + + let result = agent + .process_approval(&message, session, thread_id, None, true, false) + .await; + + assert!( + result.is_err(), + "Approving a missing thread should return an error, got: {:?}", + result + ); + } }