diff --git a/internal/core/exec/messages.go b/internal/core/exec/messages.go index 6d5bcc333..3d77f551b 100644 --- a/internal/core/exec/messages.go +++ b/internal/core/exec/messages.go @@ -57,6 +57,8 @@ type LLMMessageMetadata struct { CompletionTokens int `json:"completionTokens,omitempty"` // TotalTokens is the sum of prompt and completion tokens. TotalTokens int `json:"totalTokens,omitempty"` + // Cost is the estimated USD cost for this API call. + Cost float64 `json:"cost,omitempty"` } // ToolDefinition represents a tool that was available to the LLM. diff --git a/internal/core/exec/messages_test.go b/internal/core/exec/messages_test.go index 548052de6..08c1768ec 100644 --- a/internal/core/exec/messages_test.go +++ b/internal/core/exec/messages_test.go @@ -1,9 +1,11 @@ package exec import ( + "encoding/json" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDeduplicateSystemMessages(t *testing.T) { @@ -68,3 +70,46 @@ func TestDeduplicateSystemMessages(t *testing.T) { }) } } + +func TestLLMMessageMetadata_CostJSONRoundTrip(t *testing.T) { + t.Run("CostSerializesWithKey", func(t *testing.T) { + meta := LLMMessageMetadata{ + Provider: "openai", + Model: "gpt-4", + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + Cost: 0.0042, + } + + data, err := json.Marshal(meta) + require.NoError(t, err) + + // Verify "cost" key is present in JSON + assert.Contains(t, string(data), `"cost":`) + + var decoded LLMMessageMetadata + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, meta, decoded) + }) + + t.Run("ZeroCostOmittedFromJSON", func(t *testing.T) { + meta := LLMMessageMetadata{ + Provider: "openai", + Model: "gpt-4", + } + + data, err := json.Marshal(meta) + require.NoError(t, err) + + // "cost" should be omitted when zero (omitempty) + assert.NotContains(t, string(data), `"cost"`) + + var decoded LLMMessageMetadata + err = json.Unmarshal(data, &decoded) + require.NoError(t, 
err) + assert.InDelta(t, 0.0, decoded.Cost, 1e-9) + }) +} diff --git a/internal/runtime/builtin/agentstep/convert.go b/internal/runtime/builtin/agentstep/convert.go new file mode 100644 index 000000000..bf2bd0de9 --- /dev/null +++ b/internal/runtime/builtin/agentstep/convert.go @@ -0,0 +1,123 @@ +package agentstep + +import ( + "github.com/dagu-org/dagu/internal/agent" + "github.com/dagu-org/dagu/internal/core/exec" + "github.com/dagu-org/dagu/internal/llm" +) + +// convertMessage converts an agent.Message to one or more exec.LLMMessage values. +// User messages with tool results expand to one message per result (Role=Tool). +func convertMessage(msg agent.Message, modelCfg *agent.ModelConfig) []exec.LLMMessage { + switch msg.Type { + case agent.MessageTypeAssistant: + return []exec.LLMMessage{convertAssistantMessage(msg, modelCfg)} + + case agent.MessageTypeUser: + // When ToolResults are present, the message is a tool-result payload; Content is unused. + if len(msg.ToolResults) > 0 { + return convertToolResultMessages(msg) + } + return []exec.LLMMessage{{ + Role: exec.RoleUser, + Content: msg.Content, + }} + + case agent.MessageTypeError: + return []exec.LLMMessage{{ + Role: exec.RoleAssistant, + Content: msg.Content, + }} + + default: + return nil + } +} + +// convertAssistantMessage converts an assistant agent.Message to an exec.LLMMessage. +func convertAssistantMessage(msg agent.Message, modelCfg *agent.ModelConfig) exec.LLMMessage { + m := exec.LLMMessage{ + Role: exec.RoleAssistant, + Content: msg.Content, + } + + // Convert tool calls if present. + if len(msg.ToolCalls) > 0 { + m.ToolCalls = make([]exec.ToolCall, len(msg.ToolCalls)) + for i, tc := range msg.ToolCalls { + m.ToolCalls[i] = exec.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: exec.ToolCallFunction{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + } + } + } + + // Build metadata with provider/model and optional usage/cost. 
+ metadata := &exec.LLMMessageMetadata{ + Provider: modelCfg.Provider, + Model: modelCfg.Model, + } + if msg.Usage != nil { + metadata.PromptTokens = msg.Usage.PromptTokens + metadata.CompletionTokens = msg.Usage.CompletionTokens + metadata.TotalTokens = msg.Usage.TotalTokens + } + if msg.Cost != nil { + metadata.Cost = *msg.Cost + } + m.Metadata = metadata + + return m +} + +// contextToLLMHistory converts context messages to llm.Message for LoopConfig.History. +// System messages are filtered out since the loop handles system prompt separately. +func contextToLLMHistory(msgs []exec.LLMMessage) []llm.Message { + if len(msgs) == 0 { + return nil + } + var result []llm.Message + for _, msg := range msgs { + if msg.Role == exec.RoleSystem { + continue + } + m := llm.Message{ + Role: llm.Role(msg.Role), + Content: msg.Content, + ToolCallID: msg.ToolCallID, + } + if len(msg.ToolCalls) > 0 { + m.ToolCalls = make([]llm.ToolCall, len(msg.ToolCalls)) + for j, tc := range msg.ToolCalls { + m.ToolCalls[j] = llm.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: llm.ToolCallFunction{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + } + } + } + result = append(result, m) + } + return result +} + +// convertToolResultMessages converts a user message with tool results +// into one exec.LLMMessage per tool result (Role=Tool). 
+func convertToolResultMessages(msg agent.Message) []exec.LLMMessage { + msgs := make([]exec.LLMMessage, len(msg.ToolResults)) + for i, tr := range msg.ToolResults { + msgs[i] = exec.LLMMessage{ + Role: exec.RoleTool, + Content: tr.Content, + ToolCallID: tr.ToolCallID, + } + } + return msgs +} diff --git a/internal/runtime/builtin/agentstep/convert_test.go b/internal/runtime/builtin/agentstep/convert_test.go new file mode 100644 index 000000000..2298bf091 --- /dev/null +++ b/internal/runtime/builtin/agentstep/convert_test.go @@ -0,0 +1,180 @@ +package agentstep + +import ( + "testing" + + "github.com/dagu-org/dagu/internal/agent" + "github.com/dagu-org/dagu/internal/core/exec" + "github.com/dagu-org/dagu/internal/llm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testModelConfig() *agent.ModelConfig { + return &agent.ModelConfig{ + Provider: "openai", + Model: "gpt-4", + } +} + +func TestConvertMessage_AssistantWithUsageAndCost(t *testing.T) { + t.Parallel() + + cost := 0.0042 + msg := agent.Message{ + Type: agent.MessageTypeAssistant, + Content: "hello world", + Usage: &llm.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + Cost: &cost, + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 1) + m := result[0] + assert.Equal(t, exec.RoleAssistant, m.Role) + assert.Equal(t, "hello world", m.Content) + require.NotNil(t, m.Metadata) + assert.Equal(t, "openai", m.Metadata.Provider) + assert.Equal(t, "gpt-4", m.Metadata.Model) + assert.Equal(t, 100, m.Metadata.PromptTokens) + assert.Equal(t, 50, m.Metadata.CompletionTokens) + assert.Equal(t, 150, m.Metadata.TotalTokens) + assert.InDelta(t, 0.0042, m.Metadata.Cost, 1e-9) +} + +func TestConvertMessage_AssistantWithToolCalls(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeAssistant, + Content: "", + ToolCalls: []llm.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: 
llm.ToolCallFunction{ + Name: "bash", + Arguments: `{"command":"ls"}`, + }, + }, + { + ID: "call_2", + Type: "function", + Function: llm.ToolCallFunction{ + Name: "read", + Arguments: `{"path":"/tmp/test"}`, + }, + }, + }, + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 1) + m := result[0] + assert.Equal(t, exec.RoleAssistant, m.Role) + require.Len(t, m.ToolCalls, 2) + + assert.Equal(t, "call_1", m.ToolCalls[0].ID) + assert.Equal(t, "function", m.ToolCalls[0].Type) + assert.Equal(t, "bash", m.ToolCalls[0].Function.Name) + assert.Equal(t, `{"command":"ls"}`, m.ToolCalls[0].Function.Arguments) + + assert.Equal(t, "call_2", m.ToolCalls[1].ID) + assert.Equal(t, "read", m.ToolCalls[1].Function.Name) + assert.Equal(t, `{"path":"/tmp/test"}`, m.ToolCalls[1].Function.Arguments) +} + +func TestConvertMessage_UserNoToolResults(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeUser, + Content: "user input", + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 1) + assert.Equal(t, exec.RoleUser, result[0].Role) + assert.Equal(t, "user input", result[0].Content) + assert.Nil(t, result[0].Metadata) +} + +func TestConvertMessage_UserWithToolResults(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeUser, + ToolResults: []agent.ToolResult{ + {ToolCallID: "call_1", Content: "result 1"}, + {ToolCallID: "call_2", Content: "result 2"}, + }, + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 2) + + assert.Equal(t, exec.RoleTool, result[0].Role) + assert.Equal(t, "result 1", result[0].Content) + assert.Equal(t, "call_1", result[0].ToolCallID) + + assert.Equal(t, exec.RoleTool, result[1].Role) + assert.Equal(t, "result 2", result[1].Content) + assert.Equal(t, "call_2", result[1].ToolCallID) +} + +func TestConvertMessage_NilUsageAndCost(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeAssistant, 
+ Content: "response", + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 1) + m := result[0] + require.NotNil(t, m.Metadata) + assert.Equal(t, "openai", m.Metadata.Provider) + assert.Equal(t, "gpt-4", m.Metadata.Model) + assert.Equal(t, 0, m.Metadata.PromptTokens) + assert.Equal(t, 0, m.Metadata.CompletionTokens) + assert.Equal(t, 0, m.Metadata.TotalTokens) + assert.InDelta(t, 0.0, m.Metadata.Cost, 1e-9) +} + +func TestConvertMessage_ErrorType(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeError, + Content: "something went wrong", + } + + result := convertMessage(msg, testModelConfig()) + + require.Len(t, result, 1) + assert.Equal(t, exec.RoleAssistant, result[0].Role) + assert.Equal(t, "something went wrong", result[0].Content) + assert.Nil(t, result[0].Metadata, "error messages should not carry LLM metadata") +} + +func TestConvertMessage_UnknownType(t *testing.T) { + t.Parallel() + + msg := agent.Message{ + Type: agent.MessageTypeUIAction, + Content: "navigate", + } + + result := convertMessage(msg, testModelConfig()) + assert.Nil(t, result) +} diff --git a/internal/runtime/builtin/agentstep/executor.go b/internal/runtime/builtin/agentstep/executor.go index 5e56b6ebe..0f10139dd 100644 --- a/internal/runtime/builtin/agentstep/executor.go +++ b/internal/runtime/builtin/agentstep/executor.go @@ -24,6 +24,7 @@ import ( ) var _ executor.Executor = (*Executor)(nil) +var _ executor.ChatMessageHandler = (*Executor)(nil) func init() { executor.RegisterExecutor( @@ -36,11 +37,13 @@ func init() { // Executor runs the agent loop as a workflow step. 
type Executor struct {
	step   core.Step
	stdout io.Writer
	stderr io.Writer

	// mu is taken by Kill and by Run while it stores cancelLoop.
	mu         sync.Mutex
	cancelLoop context.CancelFunc

	// contextMessages is the conversation inherited from prior steps, as
	// supplied through SetContext. Run passes it (minus system messages) to
	// the agent loop as history.
	contextMessages []exec.LLMMessage

	// savedMessages is what GetMessages returns. Run seeds it with a copy of
	// contextMessages, and the RecordMessage callback appends every converted
	// agent message to it.
	// NOTE(review): appends are not performed under mu — confirm the loop
	// invokes RecordMessage from a single goroutine.
	savedMessages []exec.LLMMessage
}
}, OnWorking: func(working bool) { if !working { diff --git a/internal/runtime/builtin/agentstep/executor_test.go b/internal/runtime/builtin/agentstep/executor_test.go new file mode 100644 index 000000000..5b8363791 --- /dev/null +++ b/internal/runtime/builtin/agentstep/executor_test.go @@ -0,0 +1,107 @@ +package agentstep + +import ( + "testing" + + "github.com/dagu-org/dagu/internal/core/exec" + "github.com/dagu-org/dagu/internal/llm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecutor_SetContextAndGetMessages(t *testing.T) { + t.Parallel() + + e := &Executor{} + + // Initially empty. + assert.Empty(t, e.GetMessages()) + + // SetContext stores messages. + msgs := []exec.LLMMessage{ + {Role: exec.RoleSystem, Content: "be helpful"}, + {Role: exec.RoleUser, Content: "hello"}, + } + e.SetContext(msgs) + assert.Equal(t, msgs, e.contextMessages) + + // GetMessages returns savedMessages, not contextMessages. + assert.Empty(t, e.GetMessages()) + + // Simulate saved messages after execution. 
+ e.savedMessages = []exec.LLMMessage{ + {Role: exec.RoleUser, Content: "test"}, + {Role: exec.RoleAssistant, Content: "response"}, + } + assert.Len(t, e.GetMessages(), 2) + assert.Equal(t, "response", e.GetMessages()[1].Content) +} + +func TestContextToLLMHistory_NilInput(t *testing.T) { + t.Parallel() + assert.Nil(t, contextToLLMHistory(nil)) +} + +func TestContextToLLMHistory_EmptyInput(t *testing.T) { + t.Parallel() + assert.Nil(t, contextToLLMHistory([]exec.LLMMessage{})) +} + +func TestContextToLLMHistory_FiltersSystemMessages(t *testing.T) { + t.Parallel() + msgs := []exec.LLMMessage{ + {Role: exec.RoleSystem, Content: "system prompt"}, + {Role: exec.RoleUser, Content: "hello"}, + {Role: exec.RoleSystem, Content: "another system"}, + {Role: exec.RoleAssistant, Content: "hi"}, + } + result := contextToLLMHistory(msgs) + require.Len(t, result, 2) + assert.Equal(t, llm.RoleUser, result[0].Role) + assert.Equal(t, "hello", result[0].Content) + assert.Equal(t, llm.RoleAssistant, result[1].Role) + assert.Equal(t, "hi", result[1].Content) +} + +func TestContextToLLMHistory_ConvertsAllRoles(t *testing.T) { + t.Parallel() + msgs := []exec.LLMMessage{ + {Role: exec.RoleUser, Content: "question"}, + {Role: exec.RoleAssistant, Content: "answer"}, + {Role: exec.RoleTool, Content: "result", ToolCallID: "tc-1"}, + } + result := contextToLLMHistory(msgs) + require.Len(t, result, 3) + assert.Equal(t, llm.RoleUser, result[0].Role) + assert.Equal(t, llm.RoleAssistant, result[1].Role) + assert.Equal(t, llm.RoleTool, result[2].Role) + assert.Equal(t, "tc-1", result[2].ToolCallID) +} + +func TestContextToLLMHistory_ConvertsToolCalls(t *testing.T) { + t.Parallel() + msgs := []exec.LLMMessage{ + { + Role: exec.RoleAssistant, + Content: "let me check", + ToolCalls: []exec.ToolCall{ + { + ID: "call-1", + Type: "function", + Function: exec.ToolCallFunction{ + Name: "read", + Arguments: `{"path":"/tmp/f"}`, + }, + }, + }, + }, + } + result := contextToLLMHistory(msgs) + require.Len(t, 
result, 1) + require.Len(t, result[0].ToolCalls, 1) + tc := result[0].ToolCalls[0] + assert.Equal(t, "call-1", tc.ID) + assert.Equal(t, "function", tc.Type) + assert.Equal(t, "read", tc.Function.Name) + assert.Equal(t, `{"path":"/tmp/f"}`, tc.Function.Arguments) +} diff --git a/internal/runtime/builtin/agentstep/testing.go b/internal/runtime/builtin/agentstep/testing.go new file mode 100644 index 000000000..efd1c97db --- /dev/null +++ b/internal/runtime/builtin/agentstep/testing.go @@ -0,0 +1,69 @@ +package agentstep + +import ( + "context" + "io" + "os" + + "github.com/dagu-org/dagu/internal/core" + "github.com/dagu-org/dagu/internal/core/exec" + "github.com/dagu-org/dagu/internal/runtime/executor" +) + +// MockExecutorType is a test executor type that simulates a successful agent step. +const MockExecutorType = "mock-agent" + +// MockExecutor is a mock implementation for testing agent step message flow. +type MockExecutor struct { + stdout io.Writer + stderr io.Writer + contextMessages []exec.LLMMessage + messages []exec.LLMMessage +} + +var _ executor.Executor = (*MockExecutor)(nil) +var _ executor.ChatMessageHandler = (*MockExecutor)(nil) + +// NewMockExecutor creates a new mock agent executor. +func NewMockExecutor(_ context.Context, _ core.Step) (executor.Executor, error) { + return &MockExecutor{ + stdout: os.Stdout, + stderr: os.Stderr, + }, nil +} + +func (m *MockExecutor) SetStdout(out io.Writer) { m.stdout = out } +func (m *MockExecutor) SetStderr(out io.Writer) { m.stderr = out } +func (m *MockExecutor) Kill(_ os.Signal) error { return nil } +func (m *MockExecutor) Run(_ context.Context) error { + // Clone context messages, then append this step's own messages. + if len(m.contextMessages) > 0 { + m.messages = append([]exec.LLMMessage(nil), m.contextMessages...) 
+ } + m.messages = append(m.messages, + exec.LLMMessage{Role: exec.RoleUser, Content: "agent input"}, + exec.LLMMessage{ + Role: exec.RoleAssistant, + Content: "agent response", + Metadata: &exec.LLMMessageMetadata{ + Provider: "openai", + Model: "gpt-4", + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + Cost: 0.001, + }, + }, + ) + _, _ = m.stdout.Write([]byte("mock agent response\n")) + return nil +} +func (m *MockExecutor) SetContext(msgs []exec.LLMMessage) { + m.contextMessages = msgs +} +func (m *MockExecutor) GetMessages() []exec.LLMMessage { return m.messages } + +// RegisterMockExecutors registers mock agent executors for testing. +func RegisterMockExecutors() { + executor.RegisterExecutor(MockExecutorType, NewMockExecutor, nil, core.ExecutorCapabilities{Agent: true}) +} diff --git a/internal/runtime/runner.go b/internal/runtime/runner.go index 866759a3a..174966936 100644 --- a/internal/runtime/runner.go +++ b/internal/runtime/runner.go @@ -514,7 +514,9 @@ func (r *Runner) setupChatMessages(ctx context.Context, node *Node) { } step := node.Step() - if !core.SupportsLLM(step.ExecutorConfig.Type) { + + executorType := step.ExecutorConfig.Type + if !core.SupportsLLM(executorType) && !core.SupportsAgent(executorType) { return } @@ -547,18 +549,13 @@ func (r *Runner) saveChatMessages(ctx context.Context, node *Node) { return } - step := node.Step() - if !core.SupportsLLM(step.ExecutorConfig.Type) { - return - } - savedMsgs := node.GetChatMessages() if len(savedMsgs) == 0 { return } // Direct write - no read-modify-write cycle - if err := r.messagesHandler.WriteStepMessages(ctx, step.Name, savedMsgs); err != nil { + if err := r.messagesHandler.WriteStepMessages(ctx, node.Step().Name, savedMsgs); err != nil { logger.Warn(ctx, "Failed to write chat messages", tag.Error(err)) } } diff --git a/internal/runtime/runner_helper_test.go b/internal/runtime/runner_helper_test.go index b3ba75961..25b0c20fc 100644 --- a/internal/runtime/runner_helper_test.go 
+++ b/internal/runtime/runner_helper_test.go @@ -12,6 +12,7 @@ import ( "github.com/dagu-org/dagu/internal/core" "github.com/dagu-org/dagu/internal/core/exec" "github.com/dagu-org/dagu/internal/runtime" + "github.com/dagu-org/dagu/internal/runtime/builtin/agentstep" "github.com/dagu-org/dagu/internal/runtime/builtin/chat" "github.com/dagu-org/dagu/internal/test" "github.com/google/uuid" @@ -385,6 +386,11 @@ func chatStep(name string, depends ...string) core.Step { return newStep(name, withDepends(depends...), withExecutorType(core.ExecutorTypeChat)) } +func agentStep(name string, depends ...string) core.Step { + return newStep(name, withDepends(depends...), withExecutorType(core.ExecutorTypeAgent)) +} + func init() { chat.RegisterMockExecutors() + agentstep.RegisterMockExecutors() } diff --git a/internal/runtime/runner_test.go b/internal/runtime/runner_test.go index 5abfd23e2..d903ad98b 100644 --- a/internal/runtime/runner_test.go +++ b/internal/runtime/runner_test.go @@ -14,6 +14,7 @@ import ( "github.com/dagu-org/dagu/internal/core" "github.com/dagu-org/dagu/internal/core/exec" "github.com/dagu-org/dagu/internal/runtime" + "github.com/dagu-org/dagu/internal/runtime/builtin/agentstep" "github.com/dagu-org/dagu/internal/runtime/builtin/chat" "github.com/dagu-org/dagu/internal/test" "github.com/google/uuid" @@ -3011,6 +3012,73 @@ func TestRunner_ChatMessagesHandler(t *testing.T) { assert.Equal(t, 0, handler.writeCalls) }) + + t.Run("AgentStepSavesMessages", func(t *testing.T) { + t.Parallel() + + handler := newMockMessagesHandler() + r := setupRunner(t, withMessagesHandler(handler)) + + plan := r.newPlan(t, newStep("agent1", withExecutorType(agentstep.MockExecutorType))) + result := plan.assertRun(t, core.Succeeded) + result.assertNodeStatus(t, "agent1", core.NodeSucceeded) + + assert.Equal(t, 1, handler.writeCalls) + assert.NotEmpty(t, handler.messages["agent1"]) + + // Verify cost metadata was preserved + msgs := handler.messages["agent1"] + var foundCost bool + 
for _, m := range msgs { + if m.Metadata != nil && m.Metadata.Cost > 0 { + foundCost = true + assert.Equal(t, "openai", m.Metadata.Provider) + assert.Equal(t, "gpt-4", m.Metadata.Model) + assert.InDelta(t, 0.001, m.Metadata.Cost, 1e-9) + } + } + assert.True(t, foundCost, "expected at least one message with cost metadata") + }) + + t.Run("AgentStepInheritsFromDependency", func(t *testing.T) { + t.Parallel() + + handler := newMockMessagesHandler() + handler.messages["step1"] = []exec.LLMMessage{ + {Role: exec.RoleSystem, Content: "be helpful"}, + {Role: exec.RoleUser, Content: "prior message"}, + } + + r := setupRunner(t, withMessagesHandler(handler)) + + plan := r.newPlan(t, + successStep("step1"), + newStep("agent1", withDepends("step1"), withExecutorType(agentstep.MockExecutorType)), + ) + result := plan.assertRun(t, core.Succeeded) + result.assertNodeStatus(t, "agent1", core.NodeSucceeded) + + assert.Equal(t, 1, handler.writeCalls) + // The mock prepends inherited context, so saved messages should contain the inherited ones + msgs := handler.messages["agent1"] + assert.True(t, len(msgs) > 2, "expected inherited + own messages") + }) + + t.Run("HandlerNotCalledForAgentStepWithNoMessages", func(t *testing.T) { + t.Parallel() + + handler := newMockMessagesHandler() + r := setupRunner(t, withMessagesHandler(handler)) + + // agentStep helper creates a step with executor type "agent" (real executor), + // which will fail since no agent config is available — but the gate should + // allow the step through (setup/save calls won't panic). + plan := r.newPlan(t, agentStep("agent_fail")) + _ = plan.assertRun(t, core.Failed) + + // Handler must not be called: step failed so saveChatMessages is skipped. 
+ assert.Equal(t, 0, handler.writeCalls) + }) } func TestWaitStep(t *testing.T) { diff --git a/rfcs/021-llm-cost-tracking.md b/rfcs/021-llm-cost-tracking.md new file mode 100644 index 000000000..0f5b1669b --- /dev/null +++ b/rfcs/021-llm-cost-tracking.md @@ -0,0 +1,198 @@ +--- +id: "021" +title: "LLM API Cost Tracking" +status: draft +--- + +# RFC 021: LLM API Cost Tracking + +## Summary + +Add a cost tracking feature that persists per-session LLM costs and exposes a monthly per-user cost summary via a new API endpoint and UI page. Covers agent chat sessions and DAG-embedded LLM steps (`chat` and `agentstep` executors). Depends on RFC 022 (Agent Step Cost Persistence) for executor-level cost data. Reuses existing session and DAG run data — no new database or storage backend required. + +--- + +## Motivation + +The agent already captures per-message cost (`Message.Cost`) and token usage (`Message.Usage`) in session JSON files. `SessionManager.totalCost` accumulates cost in-memory during a session, but this value is lost on restart. Additionally, the `chat` executor stores token usage metadata in DAG run message files but never computes a USD cost. There is currently no way to: + +1. **View aggregated costs** — administrators cannot see how much LLM usage costs per user or per month. +2. **Attribute costs** — there is no breakdown of which users are driving LLM spend. +3. **Recover session cost** — restarting the server loses the in-memory `totalCost`; the only way to recover it is to re-sum all message costs. +4. **Track DAG step LLM costs** — the `chat` executor records token counts but not USD cost; the `agentstep` executor doesn't persist cost at all (only logs to stderr). See RFC 022 for executor-level changes. + +### Use Cases + +- An **administrator** views the monthly cost dashboard to monitor LLM spend across all cost sources (agent chat, chat steps, agent steps). 
+- A **manager** reviews cost trends to decide whether to adjust model selection or usage policies. +- A **developer** checks their own usage to stay aware of personal LLM consumption. + +--- + +## Proposal + +### 1. Persist `TotalCost` in Session Files + +Add a `TotalCost` field to `SessionForStorage` and `Session`: + +```go +// internal/persis/filesession/store.go +type SessionForStorage struct { + // ... existing fields ... + TotalCost float64 `json:"total_cost,omitempty"` +} + +// internal/agent/types.go +type Session struct { + // ... existing fields ... + TotalCost float64 `json:"total_cost,omitempty"` +} +``` + +Update `AddMessage()` in the file session store to accumulate cost: + +```go +func (s *Store) AddMessage(_ context.Context, sessionID string, msg *agent.Message) error { + // ... existing logic ... + stored.Messages = append(stored.Messages, *msg) + if msg.Cost != nil { + stored.TotalCost += *msg.Cost + } + // ... write file ... +} +``` + +Update `ToSession()` and `FromSession()` to include `TotalCost`. + +**Backward compatibility:** When loading old session files where `TotalCost` is zero but messages have costs, compute `TotalCost` by summing `Message.Cost` values. This ensures existing sessions are handled correctly without migration. + +### 2. Executor-Level Cost Persistence (RFC 022) + +RFC 022 handles adding `Cost` to `LLMMessageMetadata`, implementing `ChatMessageHandler` on the `agentstep` executor, and populating cost in the `chat` executor. This RFC assumes that work is complete — per-message cost data is available in DAG run message files for both executor types. + +### 3. 
New API Endpoint + +```http +GET /api/v1/agent/cost-summary?month=YYYY-MM&userId=optional +``` + +**Response schema:** + +```yaml +AgentCostSummary: + type: object + properties: + month: + type: string + description: "The queried month (YYYY-MM)" + entries: + type: array + items: + $ref: "#/components/schemas/AgentCostEntry" + totalCost: + type: number + format: double + +AgentCostEntry: + type: object + properties: + userId: + type: string + sessionCount: + type: integer + totalTokens: + type: integer + description: "Sum of input + output tokens" + source: + type: string + enum: [agent_chat, chat_step, agent_step] + description: "Cost origin" + totalCost: + type: number + format: double +``` + +**Implementation approach:** + +1. Add `ListUserIDs() []string` to the file session store (reads from the `byUser` in-memory index — no disk I/O). +2. Add a `GetCostSummary(ctx, month, userID)` method in `internal/agent/api.go` that: + - Lists all user IDs (or filters to one if `userId` is specified). + - For each user, lists sessions via `ListSessions`. + - Filters sessions by `CreatedAt` month. + - Sums `TotalCost` from the `Session` struct (with fallback: if `TotalCost` is zero, load messages and sum `Message.Cost`). + - Aggregates token counts from `Message.Usage`. + - Additionally scans DAG run message files for the queried month to include `chat` and `agentstep` executor costs (see Section 2). +3. Add a handler in `internal/service/frontend/api/v1/agent_cost.go` wired to the Chi router. 
+ +**Cost source breakdown in response:** + +The response includes a `source` field per entry to distinguish cost origins: + +| Source | Description | Data Location | +|--------|-------------|---------------| +| `agent_chat` | Interactive agent chat sessions | Agent session files (`SessionStore`) | +| `chat_step` | DAG `chat` executor messages | DAG run message files (`{dag-run-dir}/messages/`) via `LLMMessageMetadata.Cost` (RFC 022) | +| `agent_step` | DAG `agentstep` executor messages | DAG run message files (same path as chat_step, via `ChatMessageHandler` — RFC 022) | + +All three sources work identically in both local and shared-nothing worker modes. The `chat_step` and `agent_step` data flows through the existing `node.ChatMessages` → `DAGRunStatus` → `ReportStatus()` → `persistChatMessages()` pipeline. See RFC 022 for shared-nothing compatibility details. + +### 4. Permission Model + +Follows the audit log permission pattern: + +| Role | Access | +|------|--------| +| Admin | All users' costs | +| Manager | All users' costs | +| Operator | Own costs only | +| Viewer | Own costs only | +| Developer | Own costs only | + +The handler uses `requireAuthenticated` from the existing auth middleware. Non-manager users have the `userId` parameter forced to their own ID, ensuring they can only view their own costs. Admin and Manager roles may specify any `userId` or omit it to see all users. + +### 5. UI Page + +A new page at route `/agent-cost` following the audit-logs page pattern: + +- **Month picker** — defaults to current month, allows navigating previous months. +- **Table columns** — User, Sessions, Total Tokens, Total Cost (USD). +- **Totals row** — aggregate across all visible users. +- **Visibility** — shown in the sidebar under the "Operations" section, gated by `canViewAuditLogs` (reuses existing permission check). 
+ +**Files:** + +| File | Change | +|------|--------| +| `ui/src/pages/agent-cost/index.tsx` | New page component | +| `ui/src/menu.tsx` | Add nav item under Operations | +| `ui/src/App.tsx` | Add route | +| `api/v1/api.yaml` | Add endpoint + schemas | + +--- + +## Implementation Files + +| File | Change | +|------|--------| +| `internal/agent/types.go` | Add `TotalCost` to `Session` | +| `internal/persis/filesession/store.go` | Add `TotalCost` to `SessionForStorage`, accumulate in `AddMessage`, backward-compat sum in `ToSession` | +| `internal/agent/store.go` | Add `ListUserIDs` to `SessionStore` interface | +| `internal/agent/api.go` | Add `GetCostSummary` method (aggregates sessions + DAG run messages) | +| `api/v1/api.yaml` | Add `/agent/cost-summary` endpoint + `AgentCostSummary` / `AgentCostEntry` schemas | +| `internal/service/frontend/api/v1/agent_cost.go` | New handler with permission check | +| `ui/src/pages/agent-cost/index.tsx` | New UI page | +| `ui/src/menu.tsx` | Add nav item | +| `ui/src/App.tsx` | Add route | + +**Dependency:** RFC 022 covers executor-level changes (`messages.go`, `chat/executor.go`, `agentstep/executor.go`). + +--- + +## Out of Scope + +- **Budgets and alerts** — no spending limits or threshold notifications. +- **Real-time streaming** — cost summary is a point-in-time query, not SSE. +- **Export** — no CSV/JSON export of cost data. +- **Charts/graphs** — table only; visualization can be added later. +- **Per-model breakdown** — aggregated per-user, not split by model. +- **Cost forecasting** — no trend analysis or projection. 
diff --git a/rfcs/022-agentstep-cost-persistence.md b/rfcs/022-agentstep-cost-persistence.md new file mode 100644 index 000000000..8b30f378d --- /dev/null +++ b/rfcs/022-agentstep-cost-persistence.md @@ -0,0 +1,102 @@ +--- +id: "022" +title: "Agent Step Cost Persistence" +status: draft +--- + +# RFC 022: Agent Step Cost Persistence + +## Summary + +Persist LLM cost and token usage data from `agentstep` executor runs into DAG run files by implementing the `ChatMessageHandler` interface. Also add a `Cost` field to `LLMMessageMetadata` so both `chat` and `agentstep` executors can record per-message USD cost. + +## Motivation + +The `agentstep` executor receives per-message cost (`agent.Message.Cost`) and token usage (`agent.Message.Usage`) via the `RecordMessage` callback, but only logs them to stderr. This data is lost after execution. The `chat` executor already persists messages via `ChatMessageHandler` → `node.ChatMessages` → DAG run files, but records token counts without USD cost (since `LLMMessageMetadata` has no `Cost` field). + +RFC 021 (LLM Cost Tracking) needs per-message cost data from both executors to build the cost summary API. This RFC extracts the executor-level persistence changes as a standalone prerequisite. + +## Proposal + +### 1. Add `Cost` to `LLMMessageMetadata` + +In `internal/core/exec/messages.go`: + +```go +type LLMMessageMetadata struct { + // ... existing fields ... + Cost float64 `json:"cost,omitempty"` // USD cost for this API call +} +``` + +### 2. Implement `ChatMessageHandler` on `agentstep` Executor + +In `internal/runtime/builtin/agentstep/executor.go`: + +Add fields to `Executor`: + +```go +type Executor struct { + // ... existing fields ... 
+ contextMessages []exec.LLMMessage + savedMessages []exec.LLMMessage +} +``` + +Implement the interface (`internal/runtime/executor/executor.go:96`): + +```go +func (e *Executor) SetContext(msgs []exec.LLMMessage) { e.contextMessages = msgs } +func (e *Executor) GetMessages() []exec.LLMMessage { return e.savedMessages } +``` + +In the `RecordMessage` callback (line 215), convert `agent.Message` → `exec.LLMMessage`: + +```go +RecordMessage: func(_ context.Context, msg agent.Message) { + logMessage(stderr, msg) + e.savedMessages = append(e.savedMessages, convertMessage(msg, modelCfg)...) +}, +``` + +### 3. Message Type Mapping + +| `agent.Message` field | `exec.LLMMessage` field | +|---|---| +| `Type` = `MessageTypeUser` | `Role` = `exec.RoleUser` | +| `Type` = `MessageTypeAssistant` | `Role` = `exec.RoleAssistant` | +| `Type` = `MessageTypeError` | `Role` = `exec.RoleAssistant` | +| `ToolResults` (on a user message) | one `exec.LLMMessage` per result, with the tool role | +| `Content` | `Content` | +| `ToolCalls` | `ToolCalls` (convert `llm.ToolCall` → `exec.ToolCall`) | +| `Usage.PromptTokens` | `Metadata.PromptTokens` | +| `Usage.CompletionTokens` | `Metadata.CompletionTokens` | +| `Usage.TotalTokens` | `Metadata.TotalTokens` | +| `Cost` | `Metadata.Cost` (new field) | +| model config (from `LoopConfig`) | `Metadata.Provider`, `Metadata.Model` | + +### 4. Populate `Cost` in `chat` Executor + +In `internal/runtime/builtin/chat/executor.go`, `createResponseMetadata()` (line 757) already builds `LLMMessageMetadata` with token counts. Once the `Cost` field exists, it can be populated there from usage and model pricing. That step is deferred to RFC 021, because the cost computation formula and model pricing tables are out of scope for this RFC. + +## Shared-Nothing Worker Compatibility + +No changes needed. `DAGRunStatusProto` is a JSON-serialized `DAGRunStatus` struct. `node.ChatMessages` is already part of this struct. The runtime's `ChatMessageHandler` capture (`internal/runtime/node.go:192-213`) already calls `SetContext` before execution and `GetMessages` after. 
Any executor implementing the interface gets persistence in both local and distributed modes automatically: + +- Local: `node.ChatMessages` → `DAGRunStatus` → `WriteStepMessages()` +- Worker: `DAGRunStatus` → `ReportStatus()` gRPC → coordinator `persistChatMessages()` + +## Referenced Files + +| File | Change | +|------|--------| +| `internal/core/exec/messages.go:49` | Add `Cost` to `LLMMessageMetadata` | +| `internal/runtime/builtin/agentstep/executor.go:38-44` | Add `contextMessages`/`savedMessages` fields, implement `ChatMessageHandler` | +| `internal/runtime/builtin/agentstep/executor.go:215` | Convert `agent.Message` → `exec.LLMMessage` in `RecordMessage` callback | +| `internal/runtime/builtin/chat/executor.go:757-768` | Populate new `Cost` field in `createResponseMetadata()` (deferred — see RFC 021) | +| `internal/runtime/executor/executor.go:96` | `ChatMessageHandler` interface (no changes, already exists) | +| `internal/runtime/node.go:192-213` | Runtime capture logic (no changes, already works) | + +## Out of Scope + +- Cost aggregation, dashboard, or API endpoint (RFC 021) +- Changes to `agent.Loop` or `SessionManager` cost calculation +- Cost computation formula or model pricing tables