-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Expand file tree
/
Copy pathmoim.rs
More file actions
153 lines (130 loc) · 5.5 KB
/
moim.rs
File metadata and controls
153 lines (130 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use crate::agents::extension_manager::ExtensionManager;
use crate::conversation::message::Message;
use crate::conversation::{fix_conversation, Conversation};
use rmcp::model::Role;
use std::path::Path;
// Test-only utility: do not use in production code. It is intentionally not gated behind
// `#[cfg(test)]` because it is set from outside this crate, where this crate's `cfg(test)`
// is not active.
thread_local! {
/// When `true` on the current thread, `inject_moim` returns the conversation it was given
/// unchanged. The flag is read before `inject_moim`'s first `.await`, so it must be set on the
/// thread that first polls the `inject_moim` future.
pub static SKIP: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
/// Injects the extension manager's MOIM text into `conversation` as a user message.
///
/// The text returned by [`ExtensionManager::collect_moim`] is inserted immediately before the
/// most recent assistant message, or at the very front when there is no assistant message. The
/// spliced conversation is then repaired with [`fix_conversation`].
///
/// Some repairs are tolerated and, if any occurred, logged at `info` level:
/// - merging consecutive messages or text;
/// - removing orphaned tool requests or responses;
/// - removing empty messages;
/// - removing leading or trailing assistant messages.
///
/// Any other reported issue is logged as a warning, and the original conversation is returned
/// untouched.
///
/// The input is also returned unchanged when [`SKIP`] is set on the current thread, or when
/// `collect_moim` yields `None`.
pub async fn inject_moim(
    session_id: &str,
    conversation: Conversation,
    extension_manager: &ExtensionManager,
    working_dir: &Path,
) -> Conversation {
    // Substrings of `fix_conversation` issue reports that are an accepted side effect of
    // splicing in an extra user message. Any issue matching none of these aborts injection.
    const TOLERATED_FIXES: [&str; 8] = [
        "Merged consecutive user messages",
        "Merged consecutive assistant messages",
        "Merged text content",
        "Removed orphaned tool response",
        "Removed orphaned tool request",
        "Removed empty message",
        "Removed leading assistant message",
        "Removed trailing assistant message",
    ];

    if SKIP.with(|skip| skip.get()) {
        return conversation;
    }

    let Some(moim_text) = extension_manager
        .collect_moim(session_id, working_dir)
        .await
    else {
        return conversation;
    };

    // Place the context right before the latest assistant turn (or first, if there is none).
    let mut spliced = conversation.messages().clone();
    let insert_at = spliced
        .iter()
        .rposition(|message| message.role == Role::Assistant)
        .unwrap_or(0);
    spliced.insert(insert_at, Message::user().with_text(moim_text));

    let (repaired, issues) = fix_conversation(Conversation::new_unvalidated(spliced));
    let unexpected = issues
        .iter()
        .any(|issue| !TOLERATED_FIXES.iter().any(|&fix| issue.contains(fix)));

    if unexpected {
        tracing::warn!("MOIM injection caused unexpected issues: {:?}", issues);
        conversation
    } else {
        if !issues.is_empty() {
            tracing::info!("MOIM injection applied conversation fixes: {:?}", issues);
        }
        repaired
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::CallToolRequestParams;
    use std::path::PathBuf;

    #[tokio::test]
    async fn test_moim_injection_before_assistant() {
        let temp_dir = tempfile::tempdir().unwrap();
        let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());
        let working_dir = PathBuf::from("/test/dir");
        let conv = Conversation::new_unvalidated(vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi"),
            Message::user().with_text("Bye"),
        ]);

        let result = inject_moim("test-session-id", conv, &em, &working_dir).await;
        let msgs = result.messages();

        // The MOIM message lands before "Hi" and is merged into the preceding user message,
        // so the message count is unchanged.
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].content[0].as_text().unwrap(), "Hello");
        assert_eq!(msgs[1].content[0].as_text().unwrap(), "Hi");
        let merged_content = msgs[0]
            .content
            .iter()
            .filter_map(|c| c.as_text())
            .collect::<Vec<_>>()
            .join("");
        assert!(merged_content.contains("Hello"));
        assert!(merged_content.contains("<info-msg>"));
        assert!(merged_content.contains("Working directory: /test/dir"));
    }

    #[tokio::test]
    async fn test_moim_injection_no_assistant() {
        let temp_dir = tempfile::tempdir().unwrap();
        let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());
        let working_dir = PathBuf::from("/test/dir");
        let conv = Conversation::new_unvalidated(vec![Message::user().with_text("Hello")]);

        let result = inject_moim("test-session-id", conv, &em, &working_dir).await;

        // With no assistant message, MOIM is inserted at the front and merged with "Hello".
        assert_eq!(result.messages().len(), 1);
        let merged_content = result.messages()[0]
            .content
            .iter()
            .filter_map(|c| c.as_text())
            .collect::<Vec<_>>()
            .join("");
        assert!(merged_content.contains("Hello"));
        assert!(merged_content.contains("<info-msg>"));
        assert!(merged_content.contains("Working directory: /test/dir"));
    }

    #[tokio::test]
    async fn test_moim_with_tool_calls() {
        let temp_dir = tempfile::tempdir().unwrap();
        let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());
        let working_dir = PathBuf::from("/test/dir");
        let conv = Conversation::new_unvalidated(vec![
            Message::user().with_text("Search for something"),
            Message::assistant()
                .with_text("I'll search for you")
                .with_tool_request("search_1", Ok(CallToolRequestParams::new("search"))),
            Message::user()
                .with_tool_response("search_1", Ok(rmcp::model::CallToolResult::success(vec![]))),
            Message::assistant()
                .with_text("I need to search more")
                .with_tool_request("search_2", Ok(CallToolRequestParams::new("search"))),
            Message::user()
                .with_tool_response("search_2", Ok(rmcp::model::CallToolResult::success(vec![]))),
        ]);

        let result = inject_moim("test-session-id", conv, &em, &working_dir).await;
        let msgs = result.messages();

        assert_eq!(msgs.len(), 6);
        let moim_msg = &msgs[3];
        let has_moim = moim_msg
            .content
            .iter()
            .any(|c| c.as_text().is_some_and(|t| t.contains("<info-msg>")));
        assert!(
            has_moim,
            "MOIM should be in message before latest assistant message"
        );
    }

    #[tokio::test]
    async fn test_moim_skip_leaves_conversation_unchanged() {
        let temp_dir = tempfile::tempdir().unwrap();
        let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());
        let working_dir = PathBuf::from("/test/dir");
        let conv = Conversation::new_unvalidated(vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi"),
            Message::user().with_text("Bye"),
        ]);

        // `#[tokio::test]` uses a current-thread runtime by default, so the thread-local set
        // here is the one `inject_moim` reads. Reset it afterwards so the thread is left clean.
        SKIP.with(|f| f.set(true));
        let result = inject_moim("test-session-id", conv, &em, &working_dir).await;
        SKIP.with(|f| f.set(false));

        let msgs = result.messages();
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].content.len(), 1);
        assert_eq!(msgs[0].content[0].as_text().unwrap(), "Hello");
        let has_moim = msgs.iter().any(|m| {
            m.content
                .iter()
                .any(|c| c.as_text().is_some_and(|t| t.contains("<info-msg>")))
        });
        assert!(!has_moim, "MOIM must not be injected while SKIP is set");
    }
}