Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 94 additions & 55 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,38 +176,11 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return nil, nil
}

// Copy mutable fields under lock to avoid races with SetTools/SetModels.
agentTools := a.tools.Copy()
// Cache-Aligned Summarization: use shared agent factory for prefix
// cache alignment with Summarize.
largeModel := a.largeModel.Get()
systemPrompt := a.systemPrompt.Get()
promptPrefix := a.systemPromptPrefix.Get()
var instructions strings.Builder

for _, server := range mcp.GetStates() {
if server.State != mcp.StateConnected {
continue
}
if s := server.Client.InitializeResult().Instructions; s != "" {
instructions.WriteString(s)
instructions.WriteString("\n\n")
}
}

if s := instructions.String(); s != "" {
systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
}

if len(agentTools) > 0 {
// Add Anthropic caching to the last tool.
agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
}

agent := fantasy.NewAgent(
largeModel.Model,
fantasy.WithSystemPrompt(systemPrompt),
fantasy.WithTools(agentTools...),
fantasy.WithUserAgent(userAgent),
)
agent := a.buildAgent()

sessionLock := sync.Mutex{}
currentSession, err := a.sessions.Get(ctx, call.SessionID)
Expand Down Expand Up @@ -270,9 +243,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
FrequencyPenalty: call.FrequencyPenalty,
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = options.Messages
for i := range prepared.Messages {
prepared.Messages[i].ProviderOptions = nil
}

// Use latest tools (updated by SetTools when MCP tools change).
prepared.Tools = a.tools.Copy()
Expand All @@ -289,21 +259,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy

prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

lastSystemRoleInx := 0
systemMessageUpdated := false
for i, msg := range prepared.Messages {
// Only add cache control to the last message.
if msg.Role == fantasy.MessageRoleSystem {
lastSystemRoleInx = i
} else if !systemMessageUpdated {
prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
systemMessageUpdated = true
}
// Then add cache control to the last 2 messages.
if i > len(prepared.Messages)-3 {
prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
}
}
// Apply shared cache control markers.
a.applyCacheControl(prepared.Messages)

if promptPrefix != "" {
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
Expand Down Expand Up @@ -626,9 +583,11 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return ErrSessionBusy
}

// Copy mutable fields under lock to avoid races with SetModels.
// Cache-Aligned Summarization: use shared agent factory for prefix
// cache alignment with Run().
largeModel := a.largeModel.Get()
systemPromptPrefix := a.systemPromptPrefix.Get()
agent := a.buildAgent()

currentSession, err := a.sessions.Get(ctx, sessionID)
if err != nil {
Expand All @@ -650,10 +609,6 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
defer a.activeRequests.Del(sessionID)
defer cancel()

agent := fantasy.NewAgent(largeModel.Model,
fantasy.WithSystemPrompt(string(summaryPrompt)),
fantasy.WithUserAgent(userAgent),
)
summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Model: largeModel.Model.Model(),
Expand All @@ -672,6 +627,11 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
ProviderOptions: opts,
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = options.Messages
prepared.Tools = a.tools.Copy()

// Apply shared cache control markers.
a.applyCacheControl(prepared.Messages)

if systemPromptPrefix != "" {
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
}
Expand Down Expand Up @@ -735,6 +695,82 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return err
}

// buildSystemPrompt assembles the system prompt handed to the model: the
// configured base prompt, followed by an <mcp-instructions> section when at
// least one connected MCP server reports non-empty instructions.
//
// buildAgent calls this, and both Run and Summarize build their agent through
// buildAgent, so the two paths start from the same prompt text.
//
// NOTE(review): the instruction order follows the iteration order of
// mcp.GetStates(). If that returns a map, the order can differ from call to
// call, which would change the prompt bytes between Run and Summarize —
// confirm the order is stable.
func (a *sessionAgent) buildSystemPrompt() string {
	prompt := a.systemPrompt.Get()

	var mcpText strings.Builder
	for _, srv := range mcp.GetStates() {
		// Only servers that are currently connected contribute.
		if srv.State == mcp.StateConnected {
			text := srv.Client.InitializeResult().Instructions
			if text == "" {
				continue
			}
			mcpText.WriteString(text)
			mcpText.WriteString("\n\n")
		}
	}

	// No server had anything to say: return the base prompt untouched.
	if mcpText.Len() == 0 {
		return prompt
	}

	return prompt + "\n\n<mcp-instructions>\n" + mcpText.String() + "\n</mcp-instructions>"
}

// buildAgent constructs the fantasy.Agent shared by Run and Summarize.
//
// The agent is created from the current large model, the prompt returned by
// buildSystemPrompt, a copy of the current tool list, and the package user
// agent. When the tool list is non-empty, the last tool is given the provider
// options from getCacheControlOptions.
//
// Cache-Aligned Summarization: because both entry points go through this one
// constructor, they present the same system prompt and tool set to the
// provider. The intent is that a provider with prompt caching can reuse the
// prefix already cached by Run when Summarize sends its request.
func (a *sessionAgent) buildAgent() fantasy.Agent {
	tools := a.tools.Copy()
	if n := len(tools); n > 0 {
		// Cache-control marker goes on the final tool only.
		tools[n-1].SetProviderOptions(a.getCacheControlOptions())
	}

	model := a.largeModel.Get().Model
	prompt := a.buildSystemPrompt()

	return fantasy.NewAgent(
		model,
		fantasy.WithSystemPrompt(prompt),
		fantasy.WithTools(tools...),
		fantasy.WithUserAgent(userAgent),
	)
}

// applyCacheControl resets and re-applies cache-control provider options on
// messages, modifying the slice elements in place. The PrepareStep closures
// in both Run and Summarize call it so the two paths place their markers the
// same way.
//
// It works in three passes:
//  1. ProviderOptions is set to nil on every message.
//  2. The message just before the first non-system message is marked, i.e.
//     the last of the leading system messages. If the very first message is
//     not a system message, index 0 is marked instead. If every message is a
//     system message, this pass marks nothing.
//  3. The final two messages (or the only one) are marked.
//
// Each mark is a separate getCacheControlOptions call, so marked messages do
// not share one options value.
//
// NOTE(review): marking index 0 when there is no leading system message looks
// like a fallback rather than a deliberate choice — confirm it is intended.
func (a *sessionAgent) applyCacheControl(messages []fantasy.Message) {
	total := len(messages)

	// Pass 1: drop whatever provider options the messages arrived with.
	for idx := 0; idx < total; idx++ {
		messages[idx].ProviderOptions = nil
	}

	// Pass 2: find the first non-system message and mark its predecessor.
	firstOther := -1
	for idx := 0; idx < total; idx++ {
		if messages[idx].Role != fantasy.MessageRoleSystem {
			firstOther = idx
			break
		}
	}
	if firstOther >= 0 {
		boundary := firstOther - 1
		if boundary < 0 {
			// No leading system message; index 0 is used.
			boundary = 0
		}
		messages[boundary].ProviderOptions = a.getCacheControlOptions()
	}

	// Pass 3: mark the last two messages.
	start := total - 2
	if start < 0 {
		start = 0
	}
	for idx := start; idx < total; idx++ {
		messages[idx].ProviderOptions = a.getCacheControlOptions()
	}
}

func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
return fantasy.ProviderOptions{}
Expand Down Expand Up @@ -1310,10 +1346,13 @@ func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Mes
return convertedMessages
}

// buildSummaryPrompt constructs the prompt text for session summarization.
// buildSummaryPrompt constructs the user-message prompt for session
// summarization. The summary instructions (from templates/summary.md) are
// included here instead of as a system prompt, so the main agent's system
// prompt and tools can be reused for prompt cache hits.
func buildSummaryPrompt(todos []session.Todo) string {
var sb strings.Builder
sb.WriteString("Provide a detailed summary of our conversation above.")
sb.Write(summaryPrompt)
if len(todos) > 0 {
sb.WriteString("\n\n## Current Todo List\n\n")
for _, t := range todos {
Expand Down
2 changes: 1 addition & 1 deletion internal/agent/templates/summary.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
You are summarizing a conversation to preserve context for continuing work later.
The conversation context above needs to be compacted. Summarize the key information into a detailed context summary so work can continue without losing important details.

**Critical**: This summary will be the ONLY context available when the conversation resumes. Assume all previous messages will be lost. Be thorough.

Expand Down
Loading