-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUGFIX] History overflow bug fix #920
Changes from all commits
bcadd60
d91ab51
60c5489
1091658
492b1ab
bed6de0
c6982d9
ce81685
b1187e8
a9b8801
05092dd
16c998c
26bb94c
6054997
69296f4
b62f853
ec76f76
c60f731
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,6 +66,10 @@ pub struct ConversationState { | |
context_message_length: Option<usize>, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[error("History overflow error")] | ||
pub struct HistoryOverflowError; | ||
|
||
impl ConversationState { | ||
pub async fn new(ctx: Arc<Context>, tool_config: HashMap<String, ToolSpec>, profile: Option<String>) -> Self { | ||
let conversation_id = Alphanumeric.sample_string(&mut rand::rng(), 9); | ||
|
@@ -113,6 +117,70 @@ impl ConversationState { | |
self.history.clear(); | ||
} | ||
|
||
/// Forces the history to be valid without clearing it | ||
/// This is used when the user chooses to continue with a large history | ||
pub fn force_valid_history(&mut self) { | ||
// Find the first valid user message to keep | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could the definition of |
||
if let Some(i) = self | ||
.history | ||
.iter() | ||
.enumerate() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to mark a user-message as valid while the conversation is ongoing, instead of having to scan it retroactively? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, we wanna make sure we don't consider this as a valid message - |
||
.find(|(_, m)| -> bool { | ||
match m { | ||
ChatMessage::UserInputMessage(m) => { | ||
matches!( | ||
m.user_input_message_context.as_ref(), | ||
Some(ctx) if ctx.tool_results.as_ref().is_none_or(|v| v.is_empty()) | ||
) && !m.content.is_empty() | ||
}, | ||
ChatMessage::AssistantResponseMessage(_) => false, | ||
} | ||
}) | ||
.map(|v| v.0) | ||
{ | ||
// If we found a valid message, make sure it's at the start | ||
if i > 0 { | ||
debug!("removing the first {i} elements in the history to ensure valid start"); | ||
self.history.drain(..i); | ||
} | ||
} | ||
|
||
// If the history is too long, we need to make it valid without clearing | ||
if self.history.len() > MAX_CONVERSATION_STATE_HISTORY_LEN - 2 { | ||
// Find the oldest message that we can keep | ||
if let Some(i) = self | ||
.history | ||
.iter() | ||
.enumerate() | ||
.skip(1) // Skip the first message which should be from the user | ||
.find(|(_, m)| -> bool { | ||
match m { | ||
ChatMessage::UserInputMessage(m) => { | ||
matches!( | ||
m.user_input_message_context.as_ref(), | ||
Some(ctx) if ctx.tool_results.as_ref().is_none_or(|v| v.is_empty()) | ||
) && !m.content.is_empty() | ||
Comment on lines
+159
to
+162
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Useful to add a helper |
||
}, | ||
ChatMessage::AssistantResponseMessage(_) => false, | ||
} | ||
}) | ||
.map(|v| v.0) | ||
{ | ||
// Remove the first i elements in the history | ||
debug!("removing the first {i} elements in the history"); | ||
self.history.drain(..i); | ||
} else { | ||
// If we can't find a valid starting message, just keep the most recent messages | ||
debug!("no valid starting user message found, keeping most recent messages"); | ||
let to_keep = MAX_CONVERSATION_STATE_HISTORY_LEN - 2; | ||
if self.history.len() > to_keep { | ||
let to_remove = self.history.len() - to_keep; | ||
self.history.drain(..to_remove); | ||
} | ||
} | ||
} | ||
} | ||
|
||
pub async fn append_new_user_message(&mut self, input: String) { | ||
debug_assert!(self.next_message.is_none(), "next_message should not exist"); | ||
if let Some(next_message) = self.next_message.as_ref() { | ||
|
@@ -211,7 +279,9 @@ impl ConversationState { | |
/// 4. If the last message is from the assistant and it contains tool uses, and a next user | ||
/// message is set without tool results, then the user message will have cancelled tool | ||
/// results. | ||
pub fn fix_history(&mut self) { | ||
/// | ||
/// Returns true if the history needs overflow handling. | ||
pub fn fix_history(&mut self) -> bool { | ||
// Trim the conversation history by finding the second oldest message from the user without | ||
// tool results - this will be the new oldest message in the history. | ||
// | ||
|
@@ -241,21 +311,9 @@ impl ConversationState { | |
self.history.drain(..i); | ||
}, | ||
None => { | ||
debug!("no valid starting user message found in the history, clearing"); | ||
self.history.clear(); | ||
// Edge case: if the next message contains tool results, then we have to just | ||
// abandon them. | ||
match &mut self.next_message { | ||
Some(UserInputMessage { | ||
ref mut content, | ||
user_input_message_context: Some(ctx), | ||
.. | ||
}) if ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty()) => { | ||
*content = "The conversation history has overflowed, clearing state".to_string(); | ||
ctx.tool_results.take(); | ||
}, | ||
_ => {}, | ||
} | ||
debug!("no valid starting user message found in the history, needs handling"); | ||
// Instead of automatically clearing, return true to indicate handling is needed | ||
return true; | ||
}, | ||
} | ||
} | ||
|
@@ -318,6 +376,8 @@ impl ConversationState { | |
}, | ||
_ => {}, | ||
} | ||
|
||
false | ||
} | ||
|
||
pub fn add_tool_results(&mut self, tool_results: Vec<ToolResult>) { | ||
|
@@ -376,9 +436,13 @@ impl ConversationState { | |
/// Returns a [FigConversationState] capable of being sent by | ||
/// [fig_api_client::StreamingClient] while preparing the current conversation state to be sent | ||
/// in the next message. | ||
pub async fn as_sendable_conversation_state(&mut self) -> FigConversationState { | ||
pub async fn as_sendable_conversation_state(&mut self) -> Result<FigConversationState, HistoryOverflowError> { | ||
debug_assert!(self.next_message.is_some()); | ||
self.fix_history(); | ||
|
||
// Check if history overflow handling is needed | ||
if self.fix_history() { | ||
return Err(HistoryOverflowError); | ||
} | ||
|
||
// The current state we want to send | ||
let mut curr_state = self.clone(); | ||
|
@@ -399,11 +463,11 @@ impl ConversationState { | |
} | ||
self.history.push_back(ChatMessage::UserInputMessage(last_message)); | ||
|
||
FigConversationState { | ||
Ok(FigConversationState { | ||
conversation_id: Some(curr_state.conversation_id), | ||
user_input_message: curr_state.next_message.expect("no user input message available"), | ||
history: Some(curr_state.history.into()), | ||
} | ||
}) | ||
} | ||
|
||
pub fn current_profile(&self) -> Option<&str> { | ||
|
@@ -517,11 +581,7 @@ fn build_shell_state() -> ShellState { | |
|
||
#[cfg(test)] | ||
mod tests { | ||
use fig_api_client::model::{ | ||
AssistantResponseMessage, | ||
ToolResultStatus, | ||
ToolUse, | ||
}; | ||
use fig_api_client::model::{AssistantResponseMessage, ToolResultStatus, ToolUse}; | ||
|
||
use super::*; | ||
use crate::cli::chat::context::AMAZONQ_FILENAME; | ||
|
@@ -543,7 +603,8 @@ mod tests { | |
println!("{env_state:?}"); | ||
} | ||
|
||
fn assert_conversation_state_invariants(state: FigConversationState, i: usize) { | ||
fn assert_conversation_state_invariants(state: Result<FigConversationState, HistoryOverflowError>, i: usize) { | ||
let state = state.expect("Should be able to get conversation state"); | ||
if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { | ||
assert!( | ||
matches!(msg, ChatMessage::UserInputMessage(_)), | ||
|
@@ -613,7 +674,9 @@ mod tests { | |
let mut conversation_state = | ||
ConversationState::new(Context::new_fake(), tool_manager.load_tools().await.unwrap(), None).await; | ||
conversation_state.append_new_user_message("start".to_string()).await; | ||
for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { | ||
|
||
// Limit the number of iterations to avoid overflow errors in tests | ||
for i in 0..10 { | ||
let s = conversation_state.as_sendable_conversation_state().await; | ||
assert_conversation_state_invariants(s, i); | ||
conversation_state.push_assistant_message(AssistantResponseMessage { | ||
|
@@ -636,7 +699,9 @@ mod tests { | |
let mut conversation_state = | ||
ConversationState::new(Context::new_fake(), tool_manager.load_tools().await.unwrap(), None).await; | ||
conversation_state.append_new_user_message("start".to_string()).await; | ||
for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { | ||
|
||
// Limit the number of iterations to avoid overflow errors in tests | ||
for i in 0..10 { | ||
let s = conversation_state.as_sendable_conversation_state().await; | ||
assert_conversation_state_invariants(s, i); | ||
if i % 3 == 0 { | ||
|
@@ -680,7 +745,7 @@ mod tests { | |
let s = conversation_state.as_sendable_conversation_state().await; | ||
|
||
// Ensure that the first two messages are the fake context messages. | ||
let hist = s.history.as_ref().unwrap(); | ||
let hist = s.as_ref().ok().unwrap().history.as_ref().unwrap(); | ||
let user = &hist[0]; | ||
let assistant = &hist[1]; | ||
match (user, assistant) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
force_valid_history
->trim_history_until_valid
?