Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/core/exec/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type LLMMessageMetadata struct {
CompletionTokens int `json:"completionTokens,omitempty"`
// TotalTokens is the sum of prompt and completion tokens.
TotalTokens int `json:"totalTokens,omitempty"`
// Cost is the estimated USD cost for this API call.
Cost float64 `json:"cost,omitempty"`
}

// ToolDefinition represents a tool that was available to the LLM.
Expand Down
46 changes: 46 additions & 0 deletions internal/core/exec/messages_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package exec

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDeduplicateSystemMessages(t *testing.T) {
Expand Down Expand Up @@ -68,3 +70,47 @@ func TestDeduplicateSystemMessages(t *testing.T) {
})
}
}

// TestLLMMessageMetadata_CostJSONRoundTrip checks that the Cost field
// serializes under the "cost" key, survives a marshal/unmarshal round trip,
// and is omitted entirely when zero (omitempty).
func TestLLMMessageMetadata_CostJSONRoundTrip(t *testing.T) {
	t.Run("CostSerializesWithKey", func(t *testing.T) {
		original := LLMMessageMetadata{
			Provider:         "openai",
			Model:            "gpt-4",
			PromptTokens:     100,
			CompletionTokens: 50,
			TotalTokens:      150,
			Cost:             0.0042,
		}

		encoded, err := json.Marshal(original)
		require.NoError(t, err)

		// The "cost" key must be present in the serialized output.
		assert.Contains(t, string(encoded), `"cost":`)

		var roundTripped LLMMessageMetadata
		require.NoError(t, json.Unmarshal(encoded, &roundTripped))

		assert.Equal(t, original, roundTripped)
		assert.InDelta(t, 0.0042, roundTripped.Cost, 1e-9)
	})

	t.Run("ZeroCostOmittedFromJSON", func(t *testing.T) {
		original := LLMMessageMetadata{
			Provider: "openai",
			Model:    "gpt-4",
		}

		encoded, err := json.Marshal(original)
		require.NoError(t, err)

		// A zero cost must be dropped by the omitempty tag.
		assert.NotContains(t, string(encoded), `"cost"`)

		var roundTripped LLMMessageMetadata
		require.NoError(t, json.Unmarshal(encoded, &roundTripped))
		assert.InDelta(t, 0.0, roundTripped.Cost, 1e-9)
	})
}
122 changes: 122 additions & 0 deletions internal/runtime/builtin/agentstep/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package agentstep

import (
"github.com/dagu-org/dagu/internal/agent"
"github.com/dagu-org/dagu/internal/core/exec"
"github.com/dagu-org/dagu/internal/llm"
)

// convertMessage converts an agent.Message to one or more exec.LLMMessage values.
// User messages with tool results expand to one message per result (Role=Tool).
// convertMessage converts an agent.Message to one or more exec.LLMMessage values.
// User messages carrying tool results expand to one message per result
// (Role=Tool); unrecognized message types yield nil.
func convertMessage(msg agent.Message, modelCfg *agent.ModelConfig) []exec.LLMMessage {
	switch msg.Type {
	case agent.MessageTypeAssistant:
		return []exec.LLMMessage{convertAssistantMessage(msg, modelCfg)}

	case agent.MessageTypeUser:
		if len(msg.ToolResults) == 0 {
			return []exec.LLMMessage{{
				Role:    exec.RoleUser,
				Content: msg.Content,
			}}
		}
		return convertToolResultMessages(msg)

	case agent.MessageTypeError:
		// Error messages surface in the transcript as assistant text.
		return []exec.LLMMessage{{
			Role:    exec.RoleAssistant,
			Content: msg.Content,
		}}

	default:
		return nil
	}
}

// convertAssistantMessage converts an assistant agent.Message to an exec.LLMMessage.
// convertAssistantMessage converts an assistant agent.Message to an
// exec.LLMMessage, mirroring any tool calls and attaching metadata with the
// provider/model plus optional token usage and cost.
func convertAssistantMessage(msg agent.Message, modelCfg *agent.ModelConfig) exec.LLMMessage {
	m := exec.LLMMessage{
		Role:    exec.RoleAssistant,
		Content: msg.Content,
	}

	// Convert tool calls if present.
	if len(msg.ToolCalls) > 0 {
		m.ToolCalls = make([]exec.ToolCall, len(msg.ToolCalls))
		for i, tc := range msg.ToolCalls {
			m.ToolCalls[i] = exec.ToolCall{
				ID:   tc.ID,
				Type: tc.Type,
				Function: exec.ToolCallFunction{
					Name:      tc.Function.Name,
					Arguments: tc.Function.Arguments,
				},
			}
		}
	}

	// Build metadata with provider/model and optional usage/cost.
	// Guard against a nil model config so a missing configuration cannot
	// panic the conversion; metadata is still attached (with zero values).
	metadata := &exec.LLMMessageMetadata{}
	if modelCfg != nil {
		metadata.Provider = modelCfg.Provider
		metadata.Model = modelCfg.Model
	}
	if msg.Usage != nil {
		metadata.PromptTokens = msg.Usage.PromptTokens
		metadata.CompletionTokens = msg.Usage.CompletionTokens
		metadata.TotalTokens = msg.Usage.TotalTokens
	}
	if msg.Cost != nil {
		metadata.Cost = *msg.Cost
	}
	m.Metadata = metadata

	return m
}

// contextToLLMHistory converts context messages to llm.Message for LoopConfig.History.
// System messages are filtered out since the loop handles system prompt separately.
// contextToLLMHistory converts context messages to llm.Message for
// LoopConfig.History. System messages are filtered out since the loop handles
// the system prompt separately. Returns nil when nothing remains to convert.
func contextToLLMHistory(msgs []exec.LLMMessage) []llm.Message {
	if len(msgs) == 0 {
		return nil
	}
	var history []llm.Message
	for _, src := range msgs {
		// Skip system messages; the loop injects its own system prompt.
		if src.Role == exec.RoleSystem {
			continue
		}
		converted := llm.Message{
			Role:       llm.Role(src.Role),
			Content:    src.Content,
			ToolCallID: src.ToolCallID,
		}
		if n := len(src.ToolCalls); n > 0 {
			converted.ToolCalls = make([]llm.ToolCall, n)
			for i, call := range src.ToolCalls {
				converted.ToolCalls[i] = llm.ToolCall{
					ID:   call.ID,
					Type: call.Type,
					Function: llm.ToolCallFunction{
						Name:      call.Function.Name,
						Arguments: call.Function.Arguments,
					},
				}
			}
		}
		history = append(history, converted)
	}
	return history
}

// convertToolResultMessages converts a user message with tool results
// into one exec.LLMMessage per tool result (Role=Tool).
// convertToolResultMessages converts a user message with tool results
// into one exec.LLMMessage per tool result (Role=Tool).
func convertToolResultMessages(msg agent.Message) []exec.LLMMessage {
	out := make([]exec.LLMMessage, 0, len(msg.ToolResults))
	for _, result := range msg.ToolResults {
		out = append(out, exec.LLMMessage{
			Role:       exec.RoleTool,
			Content:    result.Content,
			ToolCallID: result.ToolCallID,
		})
	}
	return out
}
178 changes: 178 additions & 0 deletions internal/runtime/builtin/agentstep/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package agentstep

import (
"testing"

"github.com/dagu-org/dagu/internal/agent"
"github.com/dagu-org/dagu/internal/core/exec"
"github.com/dagu-org/dagu/internal/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// testModelConfig returns a fixed openai/gpt-4 model configuration for tests.
func testModelConfig() *agent.ModelConfig {
	cfg := agent.ModelConfig{
		Provider: "openai",
		Model:    "gpt-4",
	}
	return &cfg
}

// An assistant message carrying usage and cost converts to a single exec
// message with fully populated metadata.
func TestConvertMessage_AssistantWithUsageAndCost(t *testing.T) {
	t.Parallel()

	price := 0.0042
	input := agent.Message{
		Type:    agent.MessageTypeAssistant,
		Content: "hello world",
		Usage: &llm.Usage{
			PromptTokens:     100,
			CompletionTokens: 50,
			TotalTokens:      150,
		},
		Cost: &price,
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 1)
	got := out[0]
	assert.Equal(t, exec.RoleAssistant, got.Role)
	assert.Equal(t, "hello world", got.Content)
	require.NotNil(t, got.Metadata)
	assert.Equal(t, "openai", got.Metadata.Provider)
	assert.Equal(t, "gpt-4", got.Metadata.Model)
	assert.Equal(t, 100, got.Metadata.PromptTokens)
	assert.Equal(t, 50, got.Metadata.CompletionTokens)
	assert.Equal(t, 150, got.Metadata.TotalTokens)
	assert.InDelta(t, 0.0042, got.Metadata.Cost, 1e-9)
}

// Tool calls on an assistant message are mirrored one-for-one into the
// converted exec message.
func TestConvertMessage_AssistantWithToolCalls(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type:    agent.MessageTypeAssistant,
		Content: "",
		ToolCalls: []llm.ToolCall{
			{
				ID:   "call_1",
				Type: "function",
				Function: llm.ToolCallFunction{
					Name:      "bash",
					Arguments: `{"command":"ls"}`,
				},
			},
			{
				ID:   "call_2",
				Type: "function",
				Function: llm.ToolCallFunction{
					Name:      "read",
					Arguments: `{"path":"/tmp/test"}`,
				},
			},
		},
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 1)
	got := out[0]
	assert.Equal(t, exec.RoleAssistant, got.Role)
	require.Len(t, got.ToolCalls, 2)

	first, second := got.ToolCalls[0], got.ToolCalls[1]
	assert.Equal(t, "call_1", first.ID)
	assert.Equal(t, "function", first.Type)
	assert.Equal(t, "bash", first.Function.Name)
	assert.Equal(t, `{"command":"ls"}`, first.Function.Arguments)

	assert.Equal(t, "call_2", second.ID)
	assert.Equal(t, "read", second.Function.Name)
}

// A plain user message (no tool results) converts to a single user-role
// message with no metadata attached.
func TestConvertMessage_UserNoToolResults(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type:    agent.MessageTypeUser,
		Content: "user input",
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 1)
	got := out[0]
	assert.Equal(t, exec.RoleUser, got.Role)
	assert.Equal(t, "user input", got.Content)
	assert.Nil(t, got.Metadata)
}

// A user message carrying tool results expands into one tool-role message
// per result, preserving content and call IDs.
func TestConvertMessage_UserWithToolResults(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type: agent.MessageTypeUser,
		ToolResults: []agent.ToolResult{
			{ToolCallID: "call_1", Content: "result 1"},
			{ToolCallID: "call_2", Content: "result 2"},
		},
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 2)

	for i, want := range []struct {
		content string
		callID  string
	}{
		{"result 1", "call_1"},
		{"result 2", "call_2"},
	} {
		assert.Equal(t, exec.RoleTool, out[i].Role)
		assert.Equal(t, want.content, out[i].Content)
		assert.Equal(t, want.callID, out[i].ToolCallID)
	}
}

// With nil usage and cost, metadata is still attached with provider/model
// set and all numeric fields at their zero values.
func TestConvertMessage_NilUsageAndCost(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type:    agent.MessageTypeAssistant,
		Content: "response",
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 1)
	meta := out[0].Metadata
	require.NotNil(t, meta)
	assert.Equal(t, "openai", meta.Provider)
	assert.Equal(t, "gpt-4", meta.Model)
	assert.Equal(t, 0, meta.PromptTokens)
	assert.Equal(t, 0, meta.CompletionTokens)
	assert.Equal(t, 0, meta.TotalTokens)
	assert.InDelta(t, 0.0, meta.Cost, 1e-9)
}

// Error messages convert to a single assistant-role message carrying the
// error text.
func TestConvertMessage_ErrorType(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type:    agent.MessageTypeError,
		Content: "something went wrong",
	}

	out := convertMessage(input, testModelConfig())

	require.Len(t, out, 1)
	got := out[0]
	assert.Equal(t, exec.RoleAssistant, got.Role)
	assert.Equal(t, "something went wrong", got.Content)
}

// Message types without a conversion rule (e.g. UI actions) produce nil.
func TestConvertMessage_UnknownType(t *testing.T) {
	t.Parallel()

	input := agent.Message{
		Type:    agent.MessageTypeUIAction,
		Content: "navigate",
	}

	assert.Nil(t, convertMessage(input, testModelConfig()))
}
Loading
Loading