-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Cronus42/main #7182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cronus42/main #7182
Changes from all commits
9c0f1cd
cfa8bc0
844e0e6
5fdf6ac
1712df1
821cda3
c8c3129
a990adf
21e4247
b5f92b1
63d44f8
5e72c3f
b020cb2
1d91c69
daf7b95
6a08f0c
db3d537
f4606c9
618e283
53fdf87
33da6ef
ef655db
eb58959
826d961
8a382f5
5f32a0b
73589e3
cf9574d
c36cbed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| //! One-time migration script to consolidate fragmented assistant messages | ||
| //! | ||
| //! This script fixes chat histories that were broken up during streaming before | ||
| //! the consolidation fix was implemented. | ||
| //! | ||
| //! Usage: | ||
| //! cargo run --bin consolidate-messages | ||
|
|
||
| use anyhow::Result; | ||
| use goose::session::session_manager::SessionManager; | ||
|
|
||
| #[tokio::main] | ||
| async fn main() -> Result<()> { | ||
| println!("🔧 Consolidating Fragmented Messages"); | ||
| println!("====================================="); | ||
| println!(); | ||
| println!("This will merge consecutive assistant text messages that were"); | ||
| println!("fragmented during streaming. This operation is safe and can be"); | ||
| println!("run multiple times."); | ||
| println!(); | ||
|
|
||
| print!("Scanning database... "); | ||
| let count = SessionManager::consolidate_fragmented_messages().await?; | ||
| println!("done!"); | ||
| println!(); | ||
|
|
||
| if count == 0 { | ||
| println!("✅ No fragmented messages found - your database is already clean!"); | ||
| } else { | ||
| println!("✅ Successfully consolidated {} message fragments", count); | ||
| println!(" Your chat history should now display correctly!"); | ||
| } | ||
|
|
||
| println!(); | ||
| println!("🎉 Migration complete!"); | ||
| println!(); | ||
| Ok(()) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,158 @@ use serde_json::Value; | |
| use super::super::base::Usage; | ||
| use crate::conversation::message::{Message, MessageContent}; | ||
|
|
||
| /// Accumulates streaming chunks into a complete message | ||
| #[derive(Debug, Default)] | ||
| pub struct BedrockStreamAccumulator { | ||
| text_blocks: HashMap<i32, String>, | ||
| text_block_emitted_char_counts: HashMap<i32, usize>, | ||
| tool_blocks: HashMap<i32, (String, String, String)>, | ||
| role: Option<Role>, | ||
| usage: Option<bedrock::TokenUsage>, | ||
| } | ||
|
|
||
| impl BedrockStreamAccumulator { | ||
| pub fn new() -> Self { | ||
| Self::default() | ||
| } | ||
|
|
||
| pub fn handle_message_start(&mut self, role: &bedrock::ConversationRole) -> Result<()> { | ||
| self.role = Some(from_bedrock_role(role)?); | ||
| Ok(()) | ||
| } | ||
|
|
||
| pub fn handle_content_block_start( | ||
| &mut self, | ||
| index: i32, | ||
| start: &bedrock::ContentBlockStart, | ||
| ) -> Result<()> { | ||
| match start { | ||
| bedrock::ContentBlockStart::ToolUse(tool_use) => { | ||
| let tool_use_id = tool_use.tool_use_id().to_string(); | ||
| let name = tool_use.name().to_string(); | ||
| self.tool_blocks | ||
| .insert(index, (tool_use_id, name, String::new())); | ||
| } | ||
| _ => { | ||
| self.text_blocks.insert(index, String::new()); | ||
| self.text_block_emitted_char_counts.insert(index, 0); | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| pub fn handle_content_block_delta( | ||
| &mut self, | ||
| index: i32, | ||
| delta: &bedrock::ContentBlockDelta, | ||
| ) -> Result<Option<Message>> { | ||
| match delta { | ||
| bedrock::ContentBlockDelta::Text(text) => { | ||
| // Ensure the block exists (in case we get delta before start) | ||
| self.text_blocks.entry(index).or_default().push_str(text); | ||
| self.build_incremental_delta_message(index) | ||
| } | ||
| bedrock::ContentBlockDelta::ToolUse(tool_delta) => { | ||
| if let Some((_, _, json)) = self.tool_blocks.get_mut(&index) { | ||
| json.push_str(&tool_delta.input); | ||
| } | ||
| Ok(None) | ||
| } | ||
| _ => Ok(None), | ||
| } | ||
| } | ||
|
|
||
| pub fn handle_message_stop( | ||
| &mut self, | ||
| _stop_reason: bedrock::StopReason, | ||
| ) -> Result<Option<Message>> { | ||
| self.build_final_message() | ||
| } | ||
|
|
||
| pub fn handle_metadata(&mut self, usage: Option<bedrock::TokenUsage>) { | ||
| if let Some(u) = usage { | ||
| self.usage = Some(u); | ||
| } | ||
| } | ||
|
|
||
| /// Build a message with only the new text delta for streaming | ||
| fn build_incremental_delta_message(&mut self, index: i32) -> Result<Option<Message>> { | ||
| if let Some(text) = self.text_blocks.get(&index) { | ||
| let emitted_char_count = *self | ||
| .text_block_emitted_char_counts | ||
| .get(&index) | ||
| .unwrap_or(&0); | ||
| let current_char_count = text.chars().count(); | ||
|
|
||
| if current_char_count > emitted_char_count { | ||
| let delta = text.chars().skip(emitted_char_count).collect::<String>(); | ||
| self.text_block_emitted_char_counts | ||
| .insert(index, current_char_count); | ||
|
|
||
| let role = self.role.clone().unwrap_or(Role::Assistant); | ||
| let created = Utc::now().timestamp(); | ||
| let content = vec![MessageContent::text(delta)]; | ||
|
|
||
| return Ok(Some(Message::new(role, created, content))); | ||
|
Comment on lines
+95
to
+112
|
||
| } | ||
| } | ||
| Ok(None) | ||
| } | ||
|
|
||
| fn build_final_message(&self) -> Result<Option<Message>> { | ||
| let role = self.role.clone().unwrap_or(Role::Assistant); | ||
| let created = Utc::now().timestamp(); | ||
|
||
| let mut content = Vec::new(); | ||
|
|
||
| // Only include text blocks that have remaining content not yet emitted during streaming | ||
| let mut indices: Vec<_> = self.text_blocks.keys().cloned().collect(); | ||
| indices.sort(); | ||
| for idx in indices { | ||
| if let Some(text) = self.text_blocks.get(&idx) { | ||
| let emitted_char_count = | ||
| *self.text_block_emitted_char_counts.get(&idx).unwrap_or(&0); | ||
| let current_char_count = text.chars().count(); | ||
| if current_char_count > emitted_char_count { | ||
| let remaining = text.chars().skip(emitted_char_count).collect::<String>(); | ||
| if !remaining.is_empty() { | ||
| content.push(MessageContent::text(remaining)); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Tool blocks are always included as they are only complete at the end of streaming | ||
| let mut tool_indices: Vec<_> = self.tool_blocks.keys().cloned().collect(); | ||
| tool_indices.sort(); | ||
| for idx in tool_indices { | ||
| if let Some((tool_use_id, name, json)) = self.tool_blocks.get(&idx) { | ||
| if let Ok(args) = serde_json::from_str::<serde_json::Value>(json) { | ||
| let tool_call = CallToolRequestParams { | ||
| meta: None, | ||
| task: None, | ||
| name: name.clone().into(), | ||
| arguments: args.as_object().cloned(), | ||
| }; | ||
| content.push(MessageContent::tool_request( | ||
| tool_use_id.clone(), | ||
| Ok(tool_call), | ||
| )); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if content.is_empty() { | ||
| Ok(None) | ||
| } else { | ||
| Ok(Some(Message::new(role, created, content))) | ||
| } | ||
| } | ||
|
|
||
| pub fn get_usage(&self) -> Option<Usage> { | ||
| self.usage.as_ref().map(from_bedrock_usage) | ||
| } | ||
| } | ||
|
|
||
| pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> { | ||
| bedrock::Message::builder() | ||
| .role(to_bedrock_role(&message.role)) | ||
|
|
@@ -43,14 +195,8 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C | |
| MessageContent::Image(image) => { | ||
| bedrock::ContentBlock::Image(to_bedrock_image(&image.data, &image.mime_type)?) | ||
| } | ||
| MessageContent::Thinking(_) => { | ||
| // Thinking blocks are not supported in Bedrock - skip | ||
| bedrock::ContentBlock::Text("".to_string()) | ||
| } | ||
| MessageContent::RedactedThinking(_) => { | ||
| // Redacted thinking blocks are not supported in Bedrock - skip | ||
| bedrock::ContentBlock::Text("".to_string()) | ||
| } | ||
| MessageContent::Thinking(_) => bedrock::ContentBlock::Text("".to_string()), | ||
| MessageContent::RedactedThinking(_) => bedrock::ContentBlock::Text("".to_string()), | ||
| MessageContent::SystemNotification(_) => { | ||
| bail!("SystemNotification should not get passed to the provider") | ||
| } | ||
|
|
@@ -93,13 +239,10 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C | |
| .map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c.clone())) | ||
| .collect::<Result<_>>()?, | ||
| ), | ||
| Err(error) => { | ||
| // For errors, create a text content block with the error message | ||
| Some(vec![bedrock::ToolResultContentBlock::Text(format!( | ||
| "The tool call returned the following error:\n{}", | ||
| error | ||
| ))]) | ||
| } | ||
| Err(error) => Some(vec![bedrock::ToolResultContentBlock::Text(format!( | ||
| "The tool call returned the following error:\n{}", | ||
| error | ||
| ))]), | ||
| }; | ||
| bedrock::ContentBlock::ToolResult( | ||
| bedrock::ToolResultBlock::builder() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ use std::sync::Arc; | |
| use tokio::sync::Mutex; | ||
|
|
||
| use super::base::{ | ||
| LeadWorkerProviderTrait, Provider, ProviderDef, ProviderMetadata, ProviderUsage, | ||
| LeadWorkerProviderTrait, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, | ||
| }; | ||
| use super::errors::ProviderError; | ||
| use crate::conversation::message::{Message, MessageContent}; | ||
|
|
@@ -484,6 +484,41 @@ impl Provider for LeadWorkerProvider { | |
| } | ||
| } | ||
|
|
||
| async fn stream( | ||
| &self, | ||
| session_id: &str, | ||
| system: &str, | ||
| messages: &[Message], | ||
| tools: &[Tool], | ||
| ) -> Result<MessageStream, ProviderError> { | ||
| // Prefer the current active provider if it supports streaming, otherwise | ||
| // fall back to the other provider that does. | ||
| let count = *self.turn_count.lock().await; | ||
| let in_fallback = *self.in_fallback_mode.lock().await; | ||
|
|
||
| let (primary, secondary) = if count < self.lead_turns || in_fallback { | ||
| (&self.lead_provider, &self.worker_provider) | ||
| } else { | ||
| (&self.worker_provider, &self.lead_provider) | ||
| }; | ||
|
|
||
| if primary.supports_streaming() { | ||
| return primary.stream(session_id, system, messages, tools).await; | ||
| } | ||
|
|
||
| if secondary.supports_streaming() { | ||
| return secondary.stream(session_id, system, messages, tools).await; | ||
| } | ||
|
|
||
| Err(ProviderError::NotImplemented( | ||
| "streaming not implemented for lead/worker configuration".to_string(), | ||
| )) | ||
| } | ||
|
Comment on lines
+487
to
+516
|
||
|
|
||
| fn supports_streaming(&self) -> bool { | ||
| self.lead_provider.supports_streaming() || self.worker_provider.supports_streaming() | ||
| } | ||
|
|
||
| /// Check if this provider is a LeadWorkerProvider | ||
| fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> { | ||
| Some(self) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The timestamp is generated using
`Utc::now().timestamp()`, which returns seconds since epoch, but other parts of the codebase use milliseconds. This inconsistency could cause issues with message ordering and display. Consider using `Utc::now().timestamp_millis()` instead to match the timestamp format used elsewhere in the codebase (e.g., in the session_manager tests, where `chrono::Utc::now().timestamp_millis()` is used).