Skip to content

Commit 3369f5f

Browse files
authored
Merge pull request #506 from stakpak/fix/stakpak-expand-merged-tool-results
fix: expand merged tool results in Stakpak request converter
2 parents c213f1b + 59d425a commit 3369f5f

5 files changed

Lines changed: 717 additions & 49 deletions

File tree

cli/src/commands/agent/run/mode_interactive.rs

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -536,43 +536,54 @@ pub async fn run_interactive(
536536
));
537537
}
538538
if !is_cancelled {
539-
let content_parts: Vec<String> = result
540-
.content
541-
.iter()
542-
.map(|c| match c.raw.as_text() {
543-
Some(text) => text.text.clone(),
544-
None => String::new(),
545-
})
546-
.filter(|s| !s.is_empty())
547-
.collect();
548-
549-
let status = result.get_status();
550-
let result_content = if status == ToolCallResultStatus::Error
551-
&& content_parts.len() >= 2
552-
{
553-
// For error cases, preserve the original formatting
554-
let error_message = content_parts[1..].join(": ");
555-
format!("[{}] {}", content_parts[0], error_message)
539+
// If a CANCELLED result was already inserted for this tool_call
540+
// (e.g., user sent a message while the tool was in-flight),
541+
// skip adding the real result to avoid duplicate tool_call_ids.
542+
let already_resolved = messages.iter().any(|m| {
543+
m.role == Role::Tool
544+
&& m.tool_call_id.as_deref() == Some(&tool_call.id)
545+
});
546+
if already_resolved {
547+
// Skip — a CANCELLED placeholder was already inserted
556548
} else {
557-
content_parts.join("\n")
558-
};
549+
let content_parts: Vec<String> = result
550+
.content
551+
.iter()
552+
.map(|c| match c.raw.as_text() {
553+
Some(text) => text.text.clone(),
554+
None => String::new(),
555+
})
556+
.filter(|s| !s.is_empty())
557+
.collect();
558+
559+
let status = result.get_status();
560+
let result_content = if status == ToolCallResultStatus::Error
561+
&& content_parts.len() >= 2
562+
{
563+
// For error cases, preserve the original formatting
564+
let error_message = content_parts[1..].join(": ");
565+
format!("[{}] {}", content_parts[0], error_message)
566+
} else {
567+
content_parts.join("\n")
568+
};
559569

560-
messages.push(tool_result(
561-
tool_call.clone().id,
562-
result_content.clone(),
563-
));
570+
messages.push(tool_result(
571+
tool_call.clone().id,
572+
result_content.clone(),
573+
));
564574

565-
send_input_event(
566-
&input_tx,
567-
InputEvent::ToolResult(
568-
stakpak_shared::models::integrations::openai::ToolCallResult {
569-
call: tool_call.clone(),
570-
result: result_content,
571-
status,
572-
},
573-
),
574-
)
575-
.await?;
575+
send_input_event(
576+
&input_tx,
577+
InputEvent::ToolResult(
578+
stakpak_shared::models::integrations::openai::ToolCallResult {
579+
call: tool_call.clone(),
580+
result: result_content,
581+
status,
582+
},
583+
),
584+
)
585+
.await?;
586+
}
576587
}
577588
send_input_event(
578589
&input_tx,

libs/ai/src/providers/anthropic/convert.rs

Lines changed: 254 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,28 +300,86 @@ fn build_tools_with_caching(
300300
///
301301
/// Caches the last N non-system messages to maximize cache hits
302302
/// on subsequent requests in a conversation.
303+
///
304+
/// Auto-caching is applied **after** merging consecutive same-role messages
305+
/// to ensure breakpoints land on the actual final message positions. This
306+
/// prevents wasting breakpoints on intermediate blocks that get folded into
307+
/// a single merged message (e.g., multiple tool_result user messages).
303308
fn build_messages_with_caching(
304309
messages: &[Message],
305310
validator: &mut CacheControlValidator,
306311
tail_count: usize,
307312
) -> Result<Vec<AnthropicMessage>> {
308313
let non_system: Vec<&Message> = messages.iter().filter(|m| m.role != Role::System).collect();
309314

310-
let len = non_system.len();
311-
let cache_start_index = len.saturating_sub(tail_count);
312-
313-
// Phase 1: Convert each message individually
315+
// Phase 1: Convert each message individually (no auto-caching yet)
314316
let converted: Vec<AnthropicMessage> = non_system
315317
.iter()
316-
.enumerate()
317-
.map(|(i, msg)| {
318-
let should_auto_cache = tail_count > 0 && i >= cache_start_index;
319-
to_anthropic_message_with_caching(msg, validator, should_auto_cache)
320-
})
318+
.map(|msg| to_anthropic_message_with_caching(msg, validator, false))
321319
.collect::<Result<Vec<_>>>()?;
322320

323321
// Phase 2: Merge consecutive same-role messages
324-
Ok(merge_consecutive_messages(converted))
322+
let mut merged = merge_consecutive_messages(converted);
323+
324+
// Phase 3: Apply tail caching to the last N *merged* messages.
325+
// This ensures breakpoints are placed on the actual final message
326+
// boundaries after merging, not on pre-merge positions that may
327+
// end up as intermediate blocks inside a merged message.
328+
if tail_count > 0 {
329+
let len = merged.len();
330+
let cache_start = len.saturating_sub(tail_count);
331+
for msg in &mut merged[cache_start..] {
332+
apply_tail_cache_to_message(msg, validator);
333+
}
334+
}
335+
336+
Ok(merged)
337+
}
338+
339+
/// Apply ephemeral cache control to the last content block of a message.
340+
///
341+
/// Used for tail-caching after message merging to ensure cache breakpoints
342+
/// land on the actual last block of each merged message.
343+
fn apply_tail_cache_to_message(msg: &mut AnthropicMessage, validator: &mut CacheControlValidator) {
344+
let cache = crate::types::CacheControl::ephemeral();
345+
let context = if msg.role == "assistant" {
346+
CacheContext::assistant_message_part()
347+
} else {
348+
CacheContext::user_message_part()
349+
};
350+
351+
let Some(validated_cache) = validator.validate(Some(&cache), context) else {
352+
return; // Breakpoint limit exceeded
353+
};
354+
355+
let anthropic_cc = AnthropicCacheControl::from(&validated_cache);
356+
match &mut msg.content {
357+
AnthropicMessageContent::Blocks(blocks) => {
358+
if let Some(last) = blocks.last_mut() {
359+
set_block_cache_control(last, Some(anthropic_cc));
360+
}
361+
}
362+
AnthropicMessageContent::String(s) => {
363+
// Convert to blocks format to attach cache control
364+
msg.content = AnthropicMessageContent::Blocks(vec![AnthropicContent::Text {
365+
text: std::mem::take(s),
366+
cache_control: Some(anthropic_cc),
367+
}]);
368+
}
369+
}
370+
}
371+
372+
/// Set cache_control on an AnthropicContent block.
373+
fn set_block_cache_control(block: &mut AnthropicContent, cc: Option<AnthropicCacheControl>) {
374+
match block {
375+
AnthropicContent::Text { cache_control, .. }
376+
| AnthropicContent::ToolUse { cache_control, .. }
377+
| AnthropicContent::ToolResult { cache_control, .. }
378+
| AnthropicContent::Image { cache_control, .. } => *cache_control = cc,
379+
AnthropicContent::Thinking { .. } | AnthropicContent::RedactedThinking { .. } => {
380+
// Thinking blocks don't support cache_control
381+
}
382+
}
325383
}
326384

327385
/// Merge consecutive messages with the same role into single messages.
@@ -1094,4 +1152,190 @@ mod tests {
10941152
assert!(matches!(&blocks[2], AnthropicContent::ToolResult { .. }));
10951153
}
10961154
}
1155+
1156+
// --- apply_tail_cache_to_message tests ---
1157+
1158+
#[test]
1159+
fn test_apply_tail_cache_to_string_message() {
1160+
let mut validator = CacheControlValidator::new();
1161+
let mut msg = user_msg("hello");
1162+
1163+
apply_tail_cache_to_message(&mut msg, &mut validator);
1164+
1165+
// Should convert to blocks and add cache_control
1166+
match &msg.content {
1167+
AnthropicMessageContent::Blocks(blocks) => {
1168+
assert_eq!(blocks.len(), 1);
1169+
match &blocks[0] {
1170+
AnthropicContent::Text {
1171+
text,
1172+
cache_control,
1173+
} => {
1174+
assert_eq!(text, "hello");
1175+
assert!(cache_control.is_some());
1176+
}
1177+
_ => panic!("Expected Text block"),
1178+
}
1179+
}
1180+
_ => panic!("Expected Blocks content after tail cache"),
1181+
}
1182+
assert_eq!(validator.breakpoint_count(), 1);
1183+
}
1184+
1185+
#[test]
1186+
fn test_apply_tail_cache_to_blocks_message() {
1187+
let mut validator = CacheControlValidator::new();
1188+
let mut msg = user_blocks_msg(vec![
1189+
tool_result_block("t1", "result1"),
1190+
tool_result_block("t2", "result2"),
1191+
]);
1192+
1193+
apply_tail_cache_to_message(&mut msg, &mut validator);
1194+
1195+
// Only the LAST block should get cache_control
1196+
if let AnthropicMessageContent::Blocks(blocks) = &msg.content {
1197+
assert_eq!(blocks.len(), 2);
1198+
match &blocks[0] {
1199+
AnthropicContent::ToolResult { cache_control, .. } => {
1200+
assert!(cache_control.is_none(), "First block should NOT be cached");
1201+
}
1202+
_ => panic!("Expected ToolResult"),
1203+
}
1204+
match &blocks[1] {
1205+
AnthropicContent::ToolResult { cache_control, .. } => {
1206+
assert!(cache_control.is_some(), "Last block SHOULD be cached");
1207+
}
1208+
_ => panic!("Expected ToolResult"),
1209+
}
1210+
} else {
1211+
panic!("Expected Blocks content");
1212+
}
1213+
assert_eq!(validator.breakpoint_count(), 1);
1214+
}
1215+
1216+
#[test]
1217+
fn test_apply_tail_cache_respects_breakpoint_limit() {
1218+
let mut validator = CacheControlValidator::new();
1219+
let cache = crate::types::CacheControl::ephemeral();
1220+
1221+
// Exhaust all 4 breakpoints
1222+
for _ in 0..4 {
1223+
validator.validate(Some(&cache), CacheContext::user_message_part());
1224+
}
1225+
assert!(validator.is_at_limit());
1226+
1227+
let mut msg = user_msg("no room");
1228+
apply_tail_cache_to_message(&mut msg, &mut validator);
1229+
1230+
// Should remain a String (no conversion) since breakpoint was rejected
1231+
match &msg.content {
1232+
AnthropicMessageContent::String(s) => assert_eq!(s, "no room"),
1233+
_ => panic!("Should not convert to blocks when breakpoint limit exceeded"),
1234+
}
1235+
}
1236+
1237+
#[test]
1238+
fn test_tail_cache_after_merge_uses_one_breakpoint_for_merged_tool_results() {
1239+
// Scenario: 3 consecutive tool_result user messages merge into 1.
1240+
// Tail caching should use only 1 breakpoint (on the merged message),
1241+
// NOT 3 breakpoints (one per pre-merge message).
1242+
let mut validator = CacheControlValidator::new();
1243+
1244+
let mut merged = [
1245+
assistant_blocks_msg(vec![
1246+
tool_use_block("t1", "tool_a"),
1247+
tool_use_block("t2", "tool_b"),
1248+
tool_use_block("t3", "tool_c"),
1249+
]),
1250+
// Simulate 3 tool_result messages already merged into 1
1251+
user_blocks_msg(vec![
1252+
tool_result_block("t1", "result_a"),
1253+
tool_result_block("t2", "result_b"),
1254+
tool_result_block("t3", "result_c"),
1255+
]),
1256+
];
1257+
1258+
// Apply tail caching with tail_count=2 (both messages)
1259+
let len = merged.len();
1260+
let cache_start = len.saturating_sub(2);
1261+
for msg in &mut merged[cache_start..] {
1262+
apply_tail_cache_to_message(msg, &mut validator);
1263+
}
1264+
1265+
// Should use exactly 2 breakpoints (one per merged message)
1266+
assert_eq!(
1267+
validator.breakpoint_count(),
1268+
2,
1269+
"Should use 2 breakpoints, not more"
1270+
);
1271+
1272+
// Assistant message: last block (tool_use t3) should be cached
1273+
if let AnthropicMessageContent::Blocks(blocks) = &merged[0].content {
1274+
match &blocks[2] {
1275+
AnthropicContent::ToolUse { cache_control, .. } => {
1276+
assert!(cache_control.is_some(), "Last tool_use should be cached");
1277+
}
1278+
_ => panic!("Expected ToolUse"),
1279+
}
1280+
// First two should NOT be cached
1281+
for block in &blocks[..2] {
1282+
match block {
1283+
AnthropicContent::ToolUse { cache_control, .. } => {
1284+
assert!(cache_control.is_none());
1285+
}
1286+
_ => panic!("Expected ToolUse"),
1287+
}
1288+
}
1289+
}
1290+
1291+
// User message: last block (tool_result t3) should be cached
1292+
if let AnthropicMessageContent::Blocks(blocks) = &merged[1].content {
1293+
match &blocks[2] {
1294+
AnthropicContent::ToolResult { cache_control, .. } => {
1295+
assert!(cache_control.is_some(), "Last tool_result should be cached");
1296+
}
1297+
_ => panic!("Expected ToolResult"),
1298+
}
1299+
// First two should NOT be cached
1300+
for block in &blocks[..2] {
1301+
match block {
1302+
AnthropicContent::ToolResult { cache_control, .. } => {
1303+
assert!(cache_control.is_none());
1304+
}
1305+
_ => panic!("Expected ToolResult"),
1306+
}
1307+
}
1308+
}
1309+
}
1310+
1311+
#[test]
1312+
fn test_set_block_cache_control() {
1313+
let cc = AnthropicCacheControl::ephemeral();
1314+
1315+
// Text block
1316+
let mut block = text_block("hello");
1317+
set_block_cache_control(&mut block, Some(cc.clone()));
1318+
match &block {
1319+
AnthropicContent::Text { cache_control, .. } => assert!(cache_control.is_some()),
1320+
_ => panic!("Expected Text"),
1321+
}
1322+
1323+
// ToolResult block
1324+
let mut block = tool_result_block("t1", "result");
1325+
set_block_cache_control(&mut block, Some(cc.clone()));
1326+
match &block {
1327+
AnthropicContent::ToolResult { cache_control, .. } => {
1328+
assert!(cache_control.is_some())
1329+
}
1330+
_ => panic!("Expected ToolResult"),
1331+
}
1332+
1333+
// ToolUse block
1334+
let mut block = tool_use_block("t1", "tool_a");
1335+
set_block_cache_control(&mut block, Some(cc.clone()));
1336+
match &block {
1337+
AnthropicContent::ToolUse { cache_control, .. } => assert!(cache_control.is_some()),
1338+
_ => panic!("Expected ToolUse"),
1339+
}
1340+
}
10971341
}

0 commit comments

Comments
 (0)