|
| 1 | +use anyhow::Result; |
| 2 | +use async_trait::async_trait; |
| 3 | +use futures::StreamExt; |
| 4 | +use goose::agents::{Agent, AgentEvent, SessionConfig}; |
| 5 | +use goose::conversation::message::{Message, MessageContent}; |
| 6 | +use goose::conversation::Conversation; |
| 7 | +use goose::model::ModelConfig; |
| 8 | +use goose::providers::base::{ |
| 9 | + stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, |
| 10 | + ProviderUsage, Usage, |
| 11 | +}; |
| 12 | +use goose::providers::errors::ProviderError; |
| 13 | +use goose::session::session_manager::SessionType; |
| 14 | +use goose::session::Session; |
| 15 | +use rmcp::model::{AnnotateAble, CallToolRequestParams, RawContent, Tool}; |
| 16 | +use serial_test::serial; |
| 17 | +use std::sync::Arc; |
| 18 | +use tempfile::TempDir; |
| 19 | + |
| 20 | +// --------------------------------------------------------------------------- |
| 21 | +// Mock provider that recognises summarization calls via the system prompt |
| 22 | +// --------------------------------------------------------------------------- |
| 23 | + |
/// Stateless stand-in for a real provider, used by the tool-pair
/// summarization tests in this file.
struct MockSummarizationProvider;

impl MockSummarizationProvider {
    /// Creates the mock. The type is a unit struct, so there is nothing to
    /// configure.
    fn new() -> Self {
        MockSummarizationProvider
    }
}
| 31 | + |
| 32 | +#[async_trait] |
| 33 | +impl Provider for MockSummarizationProvider { |
| 34 | + async fn stream( |
| 35 | + &self, |
| 36 | + _model_config: &ModelConfig, |
| 37 | + _session_id: &str, |
| 38 | + system_prompt: &str, |
| 39 | + _messages: &[Message], |
| 40 | + _tools: &[Tool], |
| 41 | + ) -> Result<MessageStream, ProviderError> { |
| 42 | + // complete_fast → complete → stream; the summarization path passes the |
| 43 | + // indoc system prompt containing "summarize a tool call". |
| 44 | + let is_summarization = system_prompt |
| 45 | + .to_lowercase() |
| 46 | + .contains("summarize a tool call"); |
| 47 | + |
| 48 | + let message = if is_summarization { |
| 49 | + Message::assistant().with_text("A call to shell was made to list files") |
| 50 | + } else { |
| 51 | + // Regular reply — no tool requests so the agent loop exits. |
| 52 | + Message::assistant().with_text("Done.") |
| 53 | + }; |
| 54 | + |
| 55 | + let usage = ProviderUsage::new( |
| 56 | + "mock-model".to_string(), |
| 57 | + Usage::new(Some(100), Some(50), Some(150)), |
| 58 | + ); |
| 59 | + |
| 60 | + Ok(stream_from_single_message(message, usage)) |
| 61 | + } |
| 62 | + |
| 63 | + fn get_model_config(&self) -> ModelConfig { |
| 64 | + ModelConfig::new("mock-model").unwrap() |
| 65 | + } |
| 66 | + |
| 67 | + fn get_name(&self) -> &str { |
| 68 | + "mock-summarization" |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +impl ProviderDef for MockSummarizationProvider { |
| 73 | + type Provider = Self; |
| 74 | + |
| 75 | + fn metadata() -> ProviderMetadata { |
| 76 | + ProviderMetadata { |
| 77 | + name: "mock".to_string(), |
| 78 | + display_name: "Mock Summarization Provider".to_string(), |
| 79 | + description: "Mock provider for tool-pair summarization testing".to_string(), |
| 80 | + default_model: "mock-model".to_string(), |
| 81 | + known_models: vec![], |
| 82 | + model_doc_link: "".to_string(), |
| 83 | + config_keys: vec![], |
| 84 | + allows_unlisted_models: false, |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + fn from_env( |
| 89 | + _model: ModelConfig, |
| 90 | + _extensions: Vec<goose::config::ExtensionConfig>, |
| 91 | + ) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> { |
| 92 | + Box::pin(async { Ok(Self::new()) }) |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +// --------------------------------------------------------------------------- |
| 97 | +// Helpers |
| 98 | +// --------------------------------------------------------------------------- |
| 99 | + |
| 100 | +/// Build a tool-request / tool-response pair linked by `call_id`. |
| 101 | +/// Both messages carry `.with_id()` — required by the `msg.id.is_some()` |
| 102 | +/// guard at agent.rs:1586. |
| 103 | +fn create_tool_pair( |
| 104 | + call_id: &str, |
| 105 | + response_id: &str, |
| 106 | + tool_name: &str, |
| 107 | + response_text: &str, |
| 108 | +) -> Vec<Message> { |
| 109 | + vec![ |
| 110 | + Message::assistant() |
| 111 | + .with_tool_request( |
| 112 | + call_id, |
| 113 | + Ok(CallToolRequestParams { |
| 114 | + task: None, |
| 115 | + name: tool_name.to_string().into(), |
| 116 | + arguments: None, |
| 117 | + meta: None, |
| 118 | + }), |
| 119 | + ) |
| 120 | + .with_id(call_id), |
| 121 | + Message::user() |
| 122 | + .with_tool_response( |
| 123 | + call_id, |
| 124 | + Ok(rmcp::model::CallToolResult { |
| 125 | + content: vec![RawContent::text(response_text).no_annotation()], |
| 126 | + structured_content: None, |
| 127 | + is_error: Some(false), |
| 128 | + meta: None, |
| 129 | + }), |
| 130 | + ) |
| 131 | + .with_id(response_id), |
| 132 | + ] |
| 133 | +} |
| 134 | + |
| 135 | +/// Set up a session pre-populated with `messages` and sensible token counts. |
| 136 | +async fn setup_test_session( |
| 137 | + agent: &Agent, |
| 138 | + temp_dir: &TempDir, |
| 139 | + session_name: &str, |
| 140 | + messages: Vec<Message>, |
| 141 | +) -> Result<Session> { |
| 142 | + let session = agent |
| 143 | + .config |
| 144 | + .session_manager |
| 145 | + .create_session( |
| 146 | + temp_dir.path().to_path_buf(), |
| 147 | + session_name.to_string(), |
| 148 | + SessionType::Hidden, |
| 149 | + ) |
| 150 | + .await?; |
| 151 | + |
| 152 | + let conversation = Conversation::new_unvalidated(messages); |
| 153 | + agent |
| 154 | + .config |
| 155 | + .session_manager |
| 156 | + .replace_conversation(&session.id, &conversation) |
| 157 | + .await?; |
| 158 | + |
| 159 | + agent |
| 160 | + .config |
| 161 | + .session_manager |
| 162 | + .update(&session.id) |
| 163 | + .total_tokens(Some(1000)) |
| 164 | + .input_tokens(Some(600)) |
| 165 | + .output_tokens(Some(400)) |
| 166 | + .accumulated_total_tokens(Some(1000)) |
| 167 | + .accumulated_input_tokens(Some(600)) |
| 168 | + .accumulated_output_tokens(Some(400)) |
| 169 | + .apply() |
| 170 | + .await?; |
| 171 | + |
| 172 | + Ok(session) |
| 173 | +} |
| 174 | + |
| 175 | +/// Build the initial conversation: one user message + `n` tool pairs. |
| 176 | +fn build_conversation_with_tool_pairs(n: usize) -> Vec<Message> { |
| 177 | + let mut messages = vec![Message::user().with_text("list files").with_id("msg_user_0")]; |
| 178 | + for i in 1..=n { |
| 179 | + messages.extend(create_tool_pair( |
| 180 | + &format!("call_{i}"), |
| 181 | + &format!("resp_{i}"), |
| 182 | + "shell", |
| 183 | + &format!("output from tool call {i}"), |
| 184 | + )); |
| 185 | + } |
| 186 | + messages |
| 187 | +} |
| 188 | + |
| 189 | +// --------------------------------------------------------------------------- |
| 190 | +// Test 1: HistoryReplaced is emitted after tool-pair summarization |
| 191 | +// --------------------------------------------------------------------------- |
| 192 | + |
| 193 | +#[tokio::test] |
| 194 | +#[serial] |
| 195 | +async fn test_history_replaced_emitted_after_tool_pair_summarization() -> Result<()> { |
| 196 | + // cutoff=2 means summarization triggers when tool_call_count > 2. |
| 197 | + // We supply 3 tool pairs so the first one gets summarised. |
| 198 | + std::env::set_var("GOOSE_TOOL_CALL_CUTOFF", "2"); |
| 199 | + |
| 200 | + let temp_dir = TempDir::new()?; |
| 201 | + let agent = Agent::new(); |
| 202 | + |
| 203 | + let messages = build_conversation_with_tool_pairs(3); |
| 204 | + let session = |
| 205 | + setup_test_session(&agent, &temp_dir, "summarization-test", messages).await?; |
| 206 | + |
| 207 | + let mock_provider = Arc::new(MockSummarizationProvider::new()); |
| 208 | + agent.update_provider(mock_provider, &session.id).await?; |
| 209 | + |
| 210 | + let session_config = SessionConfig { |
| 211 | + id: session.id.clone(), |
| 212 | + schedule_id: None, |
| 213 | + max_turns: Some(1), |
| 214 | + retry_config: None, |
| 215 | + }; |
| 216 | + |
| 217 | + let new_user_message = Message::user() |
| 218 | + .with_text("continue") |
| 219 | + .with_id("msg_user_continue"); |
| 220 | + |
| 221 | + let reply_stream = agent.reply(new_user_message, session_config, None).await?; |
| 222 | + tokio::pin!(reply_stream); |
| 223 | + |
| 224 | + let mut history_replaced_events: Vec<Conversation> = Vec::new(); |
| 225 | + |
| 226 | + while let Some(event_result) = reply_stream.next().await { |
| 227 | + match event_result { |
| 228 | + Ok(AgentEvent::HistoryReplaced(conv)) => { |
| 229 | + history_replaced_events.push(conv); |
| 230 | + } |
| 231 | + Ok(_) => {} |
| 232 | + Err(e) => return Err(e), |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + // --- Assertions --- |
| 237 | + |
| 238 | + // 1. At least one HistoryReplaced event was emitted. |
| 239 | + assert!( |
| 240 | + !history_replaced_events.is_empty(), |
| 241 | + "Expected at least one HistoryReplaced event from tool-pair summarization" |
| 242 | + ); |
| 243 | + |
| 244 | + let final_conv = history_replaced_events.last().unwrap(); |
| 245 | + let msgs = final_conv.messages(); |
| 246 | + |
| 247 | + // 2. There should be a hidden summary message (agent-visible, user-invisible). |
| 248 | + let hidden_summaries: Vec<&Message> = msgs |
| 249 | + .iter() |
| 250 | + .filter(|m: &&Message| !m.is_user_visible() && m.is_agent_visible()) |
| 251 | + .collect(); |
| 252 | + assert!( |
| 253 | + !hidden_summaries.is_empty(), |
| 254 | + "Expected at least one hidden summary message in the conversation" |
| 255 | + ); |
| 256 | + |
| 257 | + // 3. The summary text should contain "shell" (from our mock response). |
| 258 | + let summary_text: String = hidden_summaries |
| 259 | + .iter() |
| 260 | + .flat_map(|m| m.content.iter()) |
| 261 | + .filter_map(|c| match c { |
| 262 | + MessageContent::Text(t) => Some(t.text.clone()), |
| 263 | + _ => None, |
| 264 | + }) |
| 265 | + .collect::<Vec<_>>() |
| 266 | + .join(" "); |
| 267 | + assert!( |
| 268 | + summary_text.contains("shell"), |
| 269 | + "Summary text should mention 'shell', got: {summary_text}" |
| 270 | + ); |
| 271 | + |
| 272 | + // 4. The original first tool pair should be marked agent-invisible. |
| 273 | + let agent_invisible_msgs: Vec<&Message> = msgs |
| 274 | + .iter() |
| 275 | + .filter(|m: &&Message| !m.is_agent_visible()) |
| 276 | + .collect(); |
| 277 | + assert!( |
| 278 | + agent_invisible_msgs.len() >= 2, |
| 279 | + "Expected the original tool pair (2 messages) to be marked agent-invisible, found {}", |
| 280 | + agent_invisible_msgs.len() |
| 281 | + ); |
| 282 | + |
| 283 | + std::env::remove_var("GOOSE_TOOL_CALL_CUTOFF"); |
| 284 | + Ok(()) |
| 285 | +} |
| 286 | + |
| 287 | +// --------------------------------------------------------------------------- |
| 288 | +// Test 2: Stale conversation_so_far overwrites hidden summaries |
| 289 | +// --------------------------------------------------------------------------- |
| 290 | + |
| 291 | +#[tokio::test] |
| 292 | +#[serial] |
| 293 | +async fn test_stale_conversation_overwrites_hidden_summary() -> Result<()> { |
| 294 | + std::env::set_var("GOOSE_TOOL_CALL_CUTOFF", "2"); |
| 295 | + |
| 296 | + let temp_dir = TempDir::new()?; |
| 297 | + let agent = Agent::new(); |
| 298 | + |
| 299 | + let messages = build_conversation_with_tool_pairs(3); |
| 300 | + let session = setup_test_session(&agent, &temp_dir, "desync-test", messages).await?; |
| 301 | + |
| 302 | + let mock_provider = Arc::new(MockSummarizationProvider::new()); |
| 303 | + agent.update_provider(mock_provider, &session.id).await?; |
| 304 | + |
| 305 | + let session_config = SessionConfig { |
| 306 | + id: session.id.clone(), |
| 307 | + schedule_id: None, |
| 308 | + max_turns: Some(1), |
| 309 | + retry_config: None, |
| 310 | + }; |
| 311 | + |
| 312 | + let new_user_message = Message::user() |
| 313 | + .with_text("continue") |
| 314 | + .with_id("msg_user_continue"); |
| 315 | + |
| 316 | + // Run the agent so tool-pair summarization fires. |
| 317 | + let reply_stream = agent.reply(new_user_message, session_config, None).await?; |
| 318 | + tokio::pin!(reply_stream); |
| 319 | + while let Some(event_result) = reply_stream.next().await { |
| 320 | + match event_result { |
| 321 | + Ok(_) => {} |
| 322 | + Err(e) => return Err(e), |
| 323 | + } |
| 324 | + } |
| 325 | + |
| 326 | + // --- Step 1: Read back server state and confirm hidden messages exist --- |
| 327 | + let server_session = agent |
| 328 | + .config |
| 329 | + .session_manager |
| 330 | + .get_session(&session.id, true) |
| 331 | + .await?; |
| 332 | + let server_conv = server_session.conversation.as_ref().unwrap(); |
| 333 | + let server_msgs = server_conv.messages(); |
| 334 | + |
| 335 | + let hidden_count_before = server_msgs |
| 336 | + .iter() |
| 337 | + .filter(|m: &&Message| !m.is_user_visible() && m.is_agent_visible()) |
| 338 | + .count(); |
| 339 | + assert!( |
| 340 | + hidden_count_before > 0, |
| 341 | + "Server should have at least one hidden summary after tool-pair summarization, found 0" |
| 342 | + ); |
| 343 | + |
| 344 | + // --- Step 2: Simulate stale UI — keep only user-visible messages --- |
| 345 | + let stale_messages: Vec<Message> = server_msgs |
| 346 | + .iter() |
| 347 | + .filter(|m: &&Message| m.is_user_visible()) |
| 348 | + .cloned() |
| 349 | + .collect(); |
| 350 | + |
| 351 | + let stale_conv = Conversation::new_unvalidated(stale_messages); |
| 352 | + agent |
| 353 | + .config |
| 354 | + .session_manager |
| 355 | + .replace_conversation(&session.id, &stale_conv) |
| 356 | + .await?; |
| 357 | + |
| 358 | + // --- Step 3: Read back and verify hidden summaries were wiped --- |
| 359 | + let after_session = agent |
| 360 | + .config |
| 361 | + .session_manager |
| 362 | + .get_session(&session.id, true) |
| 363 | + .await?; |
| 364 | + let after_conv = after_session.conversation.as_ref().unwrap(); |
| 365 | + let after_msgs = after_conv.messages(); |
| 366 | + |
| 367 | + let hidden_count_after = after_msgs |
| 368 | + .iter() |
| 369 | + .filter(|m: &&Message| !m.is_user_visible() && m.is_agent_visible()) |
| 370 | + .count(); |
| 371 | + assert_eq!( |
| 372 | + hidden_count_after, 0, |
| 373 | + "After replacing with stale (user-visible only) conversation, \ |
| 374 | + hidden summaries should be gone, but found {hidden_count_after}" |
| 375 | + ); |
| 376 | + |
| 377 | + std::env::remove_var("GOOSE_TOOL_CALL_CUTOFF"); |
| 378 | + Ok(()) |
| 379 | +} |
0 commit comments