diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go
index 138af470aa93..c80c5005c8f2 100644
--- a/dashboard/app/ai.go
+++ b/dashboard/app/ai.go
@@ -64,20 +64,23 @@ type uiAIResult struct {
}
type uiAITrajectorySpan struct {
- Started     time.Time
- Seq         int64
- Nesting     int64
- Type        string
- Name        string
- Model       string
- Duration    time.Duration
- Error       string
- Args        string
- Results     string
- Instruction string
- Prompt      string
- Reply       string
- Thoughts    string
+ Started              time.Time
+ Seq                  int64
+ Nesting              int64
+ Type                 string
+ Name                 string
+ Model                string
+ Duration             time.Duration
+ Error                string
+ Args                 string
+ Results              string
+ Instruction          string
+ Prompt               string
+ Reply                string
+ Thoughts             string
+ InputTokens          int
+ OutputTokens         int
+ OutputThoughtsTokens int
}
func handleAIJobsPage(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
@@ -233,20 +236,23 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
duration = span.Finished.Time.Sub(span.Started)
}
res = append(res, &uiAITrajectorySpan{
- Started:     span.Started,
- Seq:         span.Seq,
- Nesting:     span.Nesting,
- Type:        span.Type,
- Name:        span.Name,
- Model:       span.Model,
- Duration:    duration,
- Error:       nullString(span.Error),
- Args:        nullJSON(span.Args),
- Results:     nullJSON(span.Results),
- Instruction: nullString(span.Instruction),
- Prompt:      nullString(span.Prompt),
- Reply:       nullString(span.Reply),
- Thoughts:    nullString(span.Thoughts),
+ Started:              span.Started,
+ Seq:                  span.Seq,
+ Nesting:              span.Nesting,
+ Type:                 span.Type,
+ Name:                 span.Name,
+ Model:                span.Model,
+ Duration:             duration,
+ Error:                nullString(span.Error),
+ Args:                 nullJSON(span.Args),
+ Results:              nullJSON(span.Results),
+ Instruction:          nullString(span.Instruction),
+ Prompt:               nullString(span.Prompt),
+ Reply:                nullString(span.Reply),
+ Thoughts:             nullString(span.Thoughts),
+ InputTokens:          nullInt64(span.InputTokens),
+ OutputTokens:         nullInt64(span.OutputTokens),
+ OutputThoughtsTokens: nullInt64(span.OutputThoughtsTokens),
})
}
return res
@@ -642,3 +648,10 @@ func nullJSON(v spanner.NullJSON) string {
}
return fmt.Sprint(v.Value)
}
+
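+// nullInt64 converts a spanner.NullInt64 to int, mapping NULL to 0.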
+func nullInt64(v spanner.NullInt64) int {
+ if !v.Valid {
+ return 0
+ }
+ return int(v.Int64)
+}
diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go
index 872a70aceaa0..4b73a5c0ae20 100644
--- a/dashboard/app/aidb/crud.go
+++ b/dashboard/app/aidb/crud.go
@@ -173,21 +173,24 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
}
defer client.Close()
ent := TrajectorySpan{
- JobID:       jobID,
- Seq:         int64(span.Seq),
- Nesting:     int64(span.Nesting),
- Type:        string(span.Type),
- Name:        span.Name,
- Model:       span.Model,
- Started:     span.Started,
- Finished:    toNullTime(span.Finished),
- Error:       toNullString(span.Error),
- Args:        toNullJSON(span.Args),
- Results:     toNullJSON(span.Results),
- Instruction: toNullString(span.Instruction),
- Prompt:      toNullString(span.Prompt),
- Reply:       toNullString(span.Reply),
- Thoughts:    toNullString(span.Thoughts),
+ JobID:                jobID,
+ Seq:                  int64(span.Seq),
+ Nesting:              int64(span.Nesting),
+ Type:                 string(span.Type),
+ Name:                 span.Name,
+ Model:                span.Model,
+ Started:              span.Started,
+ Finished:             toNullTime(span.Finished),
+ Error:                toNullString(span.Error),
+ Args:                 toNullJSON(span.Args),
+ Results:              toNullJSON(span.Results),
+ Instruction:          toNullString(span.Instruction),
+ Prompt:               toNullString(span.Prompt),
+ Reply:                toNullString(span.Reply),
+ Thoughts:             toNullString(span.Thoughts),
+ InputTokens:          toNullInt64(span.InputTokens),
+ OutputTokens:         toNullInt64(span.OutputTokens),
+ OutputThoughtsTokens: toNullInt64(span.OutputThoughtsTokens),
}
mut, err := spanner.InsertOrUpdateStruct("TrajectorySpans", ent)
if err != nil {
@@ -290,3 +293,10 @@ func toNullString(v string) spanner.NullString {
}
return spanner.NullString{StringVal: v, Valid: true}
}
+
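+// toNullInt64 converts an int to spanner.NullInt64, storing 0 as NULL.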
+func toNullInt64(v int) spanner.NullInt64 {
+ if v == 0 {
+ return spanner.NullInt64{}
+ }
+ return spanner.NullInt64{Int64: int64(v), Valid: true}
+}
diff --git a/dashboard/app/aidb/entities.go b/dashboard/app/aidb/entities.go
index 0a3e7b16419b..4e99f14ca33a 100644
--- a/dashboard/app/aidb/entities.go
+++ b/dashboard/app/aidb/entities.go
@@ -38,18 +38,21 @@ type Job struct {
type TrajectorySpan struct {
JobID string
// The following fields correspond one-to-one to trajectory.Span fields (add field comments there).
- Seq         int64
- Nesting     int64
- Type        string
- Name        string
- Model       string
- Started     time.Time
- Finished    spanner.NullTime
- Error       spanner.NullString
- Args        spanner.NullJSON
- Results     spanner.NullJSON
- Instruction spanner.NullString
- Prompt      spanner.NullString
- Reply       spanner.NullString
- Thoughts    spanner.NullString
+ Seq                  int64
+ Nesting              int64
+ Type                 string
+ Name                 string
+ Model                string
+ Started              time.Time
+ Finished             spanner.NullTime
+ Error                spanner.NullString
+ Args                 spanner.NullJSON
+ Results              spanner.NullJSON
+ Instruction          spanner.NullString
+ Prompt               spanner.NullString
+ Reply                spanner.NullString
+ Thoughts             spanner.NullString
+ InputTokens          spanner.NullInt64
+ OutputTokens         spanner.NullInt64
+ OutputThoughtsTokens spanner.NullInt64
}
diff --git a/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql
new file mode 100644
index 000000000000..840293f60ca3
--- /dev/null
+++ b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql
@@ -0,0 +1,3 @@
+ALTER TABLE TrajectorySpans DROP COLUMN InputTokens;
+ALTER TABLE TrajectorySpans DROP COLUMN OutputTokens;
+ALTER TABLE TrajectorySpans DROP COLUMN OutputThoughtsTokens;
diff --git a/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql
new file mode 100644
index 000000000000..e80ede96ef0f
--- /dev/null
+++ b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql
@@ -0,0 +1,3 @@
+ALTER TABLE TrajectorySpans ADD COLUMN InputTokens INT64;
+ALTER TABLE TrajectorySpans ADD COLUMN OutputTokens INT64;
+ALTER TABLE TrajectorySpans ADD COLUMN OutputThoughtsTokens INT64;
diff --git a/dashboard/app/templates/ai_job.html b/dashboard/app/templates/ai_job.html
index f8f2b82bd986..e1971ac3dcc9 100644
--- a/dashboard/app/templates/ai_job.html
+++ b/dashboard/app/templates/ai_job.html
@@ -86,6 +86,13 @@
{{if $span.Reply}}
Reply:
{{end}}
+ {{if $span.InputTokens}}
+ Tokens:
+ input: {{$span.InputTokens}}
+ output: {{$span.OutputTokens}}
+ thoughts: {{$span.OutputThoughtsTokens}}
+
+ {{end}}
{{if $span.Thoughts}}
Thoughts:
{{end}}
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 405498800b77..ee83b541f321 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -143,10 +144,17 @@ var (
createClientOnce sync.Once
createClientErr error
client *genai.Client
- modelList = make(map[string]bool)
+ modelList = make(map[string]*modelInfo)
stubContextKey = contextKeyType(1)
)
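+// modelInfo caches per-model capabilities reported by the model list API.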
+type modelInfo struct {
+ Thinking         bool
+ MaxTemperature   float32
+ InputTokenLimit  int
+ OutputTokenLimit int
+}
+
func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
createClientErr = err
return
}
- modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
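+ // Skip models that can't serve our requests: image/audio variants and models
+ // without generateContent support.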
+ if !slices.Contains(m.SupportedActions, "generateContent") ||
+ strings.Contains(m.Name, "-image") ||
+ strings.Contains(m.Name, "-audio") {
+ continue
+ }
+ modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
+ Thinking: m.Thinking,
+ MaxTemperature: m.MaxTemperature,
+ InputTokenLimit: int(m.InputTokenLimit),
+ OutputTokenLimit: int(m.OutputTokenLimit),
+ }
}
})
if createClientErr != nil {
return nil, createClientErr
}
- thinking, ok := modelList[model]
- if !ok {
+ info := modelList[model]
+ if info == nil {
models := slices.Collect(maps.Keys(modelList))
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- if thinking {
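+ // Clamp the requested temperature to the model's supported maximum (when reported).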
+ if info.MaxTemperature > 0 {
+   *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
+ }
+ if info.Thinking {
// Don't alter the original object (that may affect request caching).
cfgCopy := *cfg
cfg = &cfgCopy
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 67b51dd7916a..2a03d82d0c9d 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -187,6 +187,9 @@ func TestWorkflow(t *testing.T) {
},
},
},
+ {
+ Text: "Some non-thoughts reply along with tool calls",
+ },
{
Text: "I am thinking I need to call some tools",
Thought: true,
@@ -321,7 +324,7 @@ func TestToolMisbehavior(t *testing.T) {
&LLMAgent{
Name: "smarty",
Model: "model",
- Temperature: 1,
+ Temperature: 0.5,
Reply: "Reply",
Outputs: LLMOutputs[struct {
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index f60a19a8336b..692443ac3fcf 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -9,6 +9,8 @@ import (
"maps"
"net/http"
"reflect"
+ "regexp"
+ "strconv"
"strings"
"time"
@@ -33,7 +35,7 @@ type LLMAgent struct {
// Value that controls the degree of randomness in token selection.
// Lower temperatures are good for prompts that require a less open-ended or creative response,
// while higher temperatures can lead to more diverse or creative results.
- // Must be assigned a float32 value in the range [0, 2].
+ // Must be assigned a number in the range [0, 2].
Temperature any
// If set, the agent will generate that many candidates and the outputs will be arrays
// instead of scalars.
@@ -190,8 +192,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err != nil {
return "", nil, ctx.finishSpan(reqSpan, err)
}
- reply, thoughts, calls, respErr := a.parseResponse(resp)
- reqSpan.Thoughts = thoughts
+ reply, calls, respErr := a.parseResponse(resp, reqSpan)
if err := ctx.finishSpan(reqSpan, respErr); err != nil {
return "", nil, err
}
@@ -243,7 +244,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
}
return &genai.GenerateContentConfig{
ResponseModalities: []string{"TEXT"},
- Temperature: genai.Ptr(a.Temperature.(float32)),
+ Temperature: genai.Ptr(float32(a.Temperature.(float64))),
SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
Tools: tools,
}, instruction, toolMap
@@ -300,8 +301,8 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
return responses, outputs, nil
}
-func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
- reply, thoughts string, calls []*genai.FunctionCall, err error) {
+func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *trajectory.Span) (
+ reply string, calls []*genai.FunctionCall, err error) {
if len(resp.Candidates) == 0 || resp.Candidates[0] == nil {
err = fmt.Errorf("empty model response")
if resp.PromptFeedback != nil {
@@ -320,6 +321,13 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
err = fmt.Errorf("unexpected reply fields (%+v)", *candidate)
return
}
+ if resp.UsageMetadata != nil {
+ // We include ToolUsePromptTokenCount just in case, though Gemini does not currently set it.
+ span.InputTokens = int(resp.UsageMetadata.PromptTokenCount) +
+ int(resp.UsageMetadata.ToolUsePromptTokenCount)
+ span.OutputTokens = int(resp.UsageMetadata.CandidatesTokenCount)
+ span.OutputThoughtsTokens = int(resp.UsageMetadata.ThoughtsTokenCount)
+ }
for _, part := range candidate.Content.Parts {
// We don't expect to receive these now.
if part.VideoMetadata != nil || part.InlineData != nil ||
@@ -331,7 +339,7 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
if part.FunctionCall != nil {
calls = append(calls, part.FunctionCall)
} else if part.Thought {
- thoughts += part.Text
+ span.Thoughts += part.Text
} else {
reply += part.Text
}
@@ -339,6 +347,11 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
if strings.TrimSpace(reply) == "" {
reply = ""
}
+ // If there is any reply along with tool calls, append it to thoughts.
+ // Otherwise it won't show up in the trajectory anywhere.
+ if len(calls) != 0 && reply != "" {
+ span.Thoughts += "\n" + reply
+ }
return
}
@@ -355,14 +368,22 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
continue
}
if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
- strings.Contains(apiErr.Message, "Quota exceeded for metric") &&
- strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
- return resp, &modelQuotaError{ctx.modelName(a.Model)}
+ strings.Contains(apiErr.Message, "Quota exceeded for metric") {
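+ // If the error carries a suggested retry delay, sleep it out (rounding up a second) and retry.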
+ if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
+ sec, _ := strconv.Atoi(match[1])
+ time.Sleep(time.Duration(sec+1) * time.Second)
+ continue
+ }
+ if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
+ return resp, &modelQuotaError{ctx.modelName(a.Model)}
+ }
}
return resp, err
}
}
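+
+// rePleaseRetry captures the whole number of seconds from the retry hint in quota
+// error messages, e.g. "Please retry in 27s" or "Please retry in 27.5s".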
+var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")
+
func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig,
req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
type Cached struct {
@@ -389,10 +410,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
ctx.requireNotEmpty(a.Name, "Model", a.Model)
ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
- a.Temperature = float32(temp)
+ a.Temperature = float64(temp)
}
- if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
- ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
+ if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
+ ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
}
if a.Candidates < 0 || a.Candidates > 100 {
ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
index 370954d981fc..6c6478020e83 100644
--- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json
+++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
@@ -10,7 +10,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -132,7 +132,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -367,7 +367,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -618,7 +618,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -915,7 +915,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
diff --git a/pkg/aflow/testdata/TestWorkflow.llm.json b/pkg/aflow/testdata/TestWorkflow.llm.json
index 681552571068..d7e490daaf93 100644
--- a/pkg/aflow/testdata/TestWorkflow.llm.json
+++ b/pkg/aflow/testdata/TestWorkflow.llm.json
@@ -320,6 +320,9 @@
"name": "tool2"
}
},
+ {
+ "text": "Some non-thoughts reply along with tool calls"
+ },
{
"text": "I am thinking I need to call some tools",
"thought": true
@@ -523,6 +526,9 @@
"name": "tool2"
}
},
+ {
+ "text": "Some non-thoughts reply along with tool calls"
+ },
{
"text": "I am thinking I need to call some tools",
"thought": true
diff --git a/pkg/aflow/testdata/TestWorkflow.trajectory.json b/pkg/aflow/testdata/TestWorkflow.trajectory.json
index 3a4859bbf154..ab3e4989983f 100644
--- a/pkg/aflow/testdata/TestWorkflow.trajectory.json
+++ b/pkg/aflow/testdata/TestWorkflow.trajectory.json
@@ -51,7 +51,7 @@
"Model": "model1",
"Started": "0001-01-01T00:00:05Z",
"Finished": "0001-01-01T00:00:06Z",
- "Thoughts": "I am thinking I need to call some tools"
+ "Thoughts": "I am thinking I need to call some tools\nSome non-thoughts reply along with tool calls"
},
{
"Seq": 4,
diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go
index ad22b018e7cb..fc63235cfeb1 100644
--- a/pkg/aflow/trajectory/trajectory.go
+++ b/pkg/aflow/trajectory/trajectory.go
@@ -38,6 +38,12 @@ type Span struct {
// LLM invocation.
Thoughts string `json:",omitzero"`
+
+ // For details see:
+ // https://pkg.go.dev/google.golang.org/genai#GenerateContentResponseUsageMetadata
+ InputTokens          int `json:",omitzero"`
+ OutputTokens         int `json:",omitzero"`
+ OutputThoughtsTokens int `json:",omitzero"`
}
type SpanType string
@@ -89,6 +95,8 @@ func (span *Span) String() string {
}
fmt.Fprintf(sb, "reply:\n%v\n", span.Reply)
case SpanLLM:
+ fmt.Fprintf(sb, "tokens: input=%v output=%v thoughts=%v\n",
+ span.InputTokens, span.OutputTokens, span.OutputThoughtsTokens)
if span.Thoughts != "" {
fmt.Fprintf(sb, "thoughts:\n%v\n", span.Thoughts)
}