diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go
index 138af470aa93..c80c5005c8f2 100644
--- a/dashboard/app/ai.go
+++ b/dashboard/app/ai.go
@@ -64,20 +64,23 @@ type uiAIResult struct {
 }
 
 type uiAITrajectorySpan struct {
-	Started     time.Time
-	Seq         int64
-	Nesting     int64
-	Type        string
-	Name        string
-	Model       string
-	Duration    time.Duration
-	Error       string
-	Args        string
-	Results     string
-	Instruction string
-	Prompt      string
-	Reply       string
-	Thoughts    string
+	Started              time.Time
+	Seq                  int64
+	Nesting              int64
+	Type                 string
+	Name                 string
+	Model                string
+	Duration             time.Duration
+	Error                string
+	Args                 string
+	Results              string
+	Instruction          string
+	Prompt               string
+	Reply                string
+	Thoughts             string
+	InputTokens          int
+	OutputTokens         int
+	OutputThoughtsTokens int
 }
 
 func handleAIJobsPage(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
@@ -233,20 +236,23 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
 			duration = span.Finished.Time.Sub(span.Started)
 		}
 		res = append(res, &uiAITrajectorySpan{
-			Started:     span.Started,
-			Seq:         span.Seq,
-			Nesting:     span.Nesting,
-			Type:        span.Type,
-			Name:        span.Name,
-			Model:       span.Model,
-			Duration:    duration,
-			Error:       nullString(span.Error),
-			Args:        nullJSON(span.Args),
-			Results:     nullJSON(span.Results),
-			Instruction: nullString(span.Instruction),
-			Prompt:      nullString(span.Prompt),
-			Reply:       nullString(span.Reply),
-			Thoughts:    nullString(span.Thoughts),
+			Started:              span.Started,
+			Seq:                  span.Seq,
+			Nesting:              span.Nesting,
+			Type:                 span.Type,
+			Name:                 span.Name,
+			Model:                span.Model,
+			Duration:             duration,
+			Error:                nullString(span.Error),
+			Args:                 nullJSON(span.Args),
+			Results:              nullJSON(span.Results),
+			Instruction:          nullString(span.Instruction),
+			Prompt:               nullString(span.Prompt),
+			Reply:                nullString(span.Reply),
+			Thoughts:             nullString(span.Thoughts),
+			InputTokens:          nullInt64(span.InputTokens),
+			OutputTokens:         nullInt64(span.OutputTokens),
+			OutputThoughtsTokens: nullInt64(span.OutputThoughtsTokens),
 		})
 	}
 	return res
@@ -642,3 +648,10 @@ func nullJSON(v spanner.NullJSON) string {
 	}
 	return fmt.Sprint(v.Value)
 }
+
+func nullInt64(v spanner.NullInt64) int {
+	if !v.Valid {
+		return 0
+	}
+	return int(v.Int64)
+}
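Note: nullInt64 above is the read-side twin of toNullInt64 added in aidb/crud.go below. A minimal standalone sketch of the round-trip convention the pair adopts, assuming cloud.google.com/go/spanner; the helper bodies mirror the patch:

```go
// Sketch of the NULL round-trip: a zero token count is stored as SQL NULL,
// and NULL reads back as zero. The dashboard therefore cannot distinguish
// "no usage reported" from "0 tokens", which is fine here because the
// {{if $span.InputTokens}} template guard hides the block in both cases.
package main

import (
	"fmt"

	"cloud.google.com/go/spanner"
)

func toNullInt64(v int) spanner.NullInt64 {
	if v == 0 {
		return spanner.NullInt64{} // Valid=false -> stored as NULL.
	}
	return spanner.NullInt64{Int64: int64(v), Valid: true}
}

func nullInt64(v spanner.NullInt64) int {
	if !v.Valid {
		return 0
	}
	return int(v.Int64)
}

func main() {
	for _, tokens := range []int{0, 1024} {
		fmt.Printf("%d -> %v -> %d\n", tokens, toNullInt64(tokens), nullInt64(toNullInt64(tokens)))
	}
}
```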
diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go
index 872a70aceaa0..4b73a5c0ae20 100644
--- a/dashboard/app/aidb/crud.go
+++ b/dashboard/app/aidb/crud.go
@@ -173,21 +173,24 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
 	}
 	defer client.Close()
 	ent := TrajectorySpan{
-		JobID:       jobID,
-		Seq:         int64(span.Seq),
-		Nesting:     int64(span.Nesting),
-		Type:        string(span.Type),
-		Name:        span.Name,
-		Model:       span.Model,
-		Started:     span.Started,
-		Finished:    toNullTime(span.Finished),
-		Error:       toNullString(span.Error),
-		Args:        toNullJSON(span.Args),
-		Results:     toNullJSON(span.Results),
-		Instruction: toNullString(span.Instruction),
-		Prompt:      toNullString(span.Prompt),
-		Reply:       toNullString(span.Reply),
-		Thoughts:    toNullString(span.Thoughts),
+		JobID:                jobID,
+		Seq:                  int64(span.Seq),
+		Nesting:              int64(span.Nesting),
+		Type:                 string(span.Type),
+		Name:                 span.Name,
+		Model:                span.Model,
+		Started:              span.Started,
+		Finished:             toNullTime(span.Finished),
+		Error:                toNullString(span.Error),
+		Args:                 toNullJSON(span.Args),
+		Results:              toNullJSON(span.Results),
+		Instruction:          toNullString(span.Instruction),
+		Prompt:               toNullString(span.Prompt),
+		Reply:                toNullString(span.Reply),
+		Thoughts:             toNullString(span.Thoughts),
+		InputTokens:          toNullInt64(span.InputTokens),
+		OutputTokens:         toNullInt64(span.OutputTokens),
+		OutputThoughtsTokens: toNullInt64(span.OutputThoughtsTokens),
 	}
 	mut, err := spanner.InsertOrUpdateStruct("TrajectorySpans", ent)
 	if err != nil {
@@ -290,3 +293,10 @@ func toNullString(v string) spanner.NullString {
 	}
 	return spanner.NullString{StringVal: v, Valid: true}
 }
+
+func toNullInt64(v int) spanner.NullInt64 {
+	if v == 0 {
+		return spanner.NullInt64{}
+	}
+	return spanner.NullInt64{Int64: int64(v), Valid: true}
+}
diff --git a/dashboard/app/aidb/entities.go b/dashboard/app/aidb/entities.go
index 0a3e7b16419b..4e99f14ca33a 100644
--- a/dashboard/app/aidb/entities.go
+++ b/dashboard/app/aidb/entities.go
@@ -38,18 +38,21 @@ type Job struct {
 type TrajectorySpan struct {
 	JobID string
 	// The following fields correspond one-to-one to trajectory.Span fields (add field comments there).
-	Seq         int64
-	Nesting     int64
-	Type        string
-	Name        string
-	Model       string
-	Started     time.Time
-	Finished    spanner.NullTime
-	Error       spanner.NullString
-	Args        spanner.NullJSON
-	Results     spanner.NullJSON
-	Instruction spanner.NullString
-	Prompt      spanner.NullString
-	Reply       spanner.NullString
-	Thoughts    spanner.NullString
+	Seq                  int64
+	Nesting              int64
+	Type                 string
+	Name                 string
+	Model                string
+	Started              time.Time
+	Finished             spanner.NullTime
+	Error                spanner.NullString
+	Args                 spanner.NullJSON
+	Results              spanner.NullJSON
+	Instruction          spanner.NullString
+	Prompt               spanner.NullString
+	Reply                spanner.NullString
+	Thoughts             spanner.NullString
+	InputTokens          spanner.NullInt64
+	OutputTokens         spanner.NullInt64
+	OutputThoughtsTokens spanner.NullInt64
 }
diff --git a/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql
new file mode 100644
index 000000000000..840293f60ca3
--- /dev/null
+++ b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.down.sql
@@ -0,0 +1,3 @@
+ALTER TABLE TrajectorySpans DROP COLUMN InputTokens;
+ALTER TABLE TrajectorySpans DROP COLUMN OutputTokens;
+ALTER TABLE TrajectorySpans DROP COLUMN OutputThoughtsTokens;
diff --git a/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql
new file mode 100644
index 000000000000..e80ede96ef0f
--- /dev/null
+++ b/dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql
@@ -0,0 +1,3 @@
+ALTER TABLE TrajectorySpans ADD COLUMN InputTokens INT64;
+ALTER TABLE TrajectorySpans ADD COLUMN OutputTokens INT64;
+ALTER TABLE TrajectorySpans ADD COLUMN OutputThoughtsTokens INT64;
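Note: the migration adds nullable INT64 columns, so existing rows need no backfill; they simply read back as NULL. A hypothetical sketch (not part of the patch) of aggregating per-job token usage over the new columns; the database path and job ID are placeholders:

```go
// Hypothetical aggregation over the new columns. IFNULL keeps the sums
// well-defined for rows written before migration 5, where they are NULL.
package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/spanner"
)

func main() {
	ctx := context.Background()
	// Placeholder database path; the real one comes from the dashboard config.
	client, err := spanner.NewClient(ctx, "projects/p/instances/i/databases/d")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
	stmt := spanner.Statement{
		SQL: `SELECT IFNULL(SUM(InputTokens), 0), IFNULL(SUM(OutputTokens), 0),
			IFNULL(SUM(OutputThoughtsTokens), 0)
			FROM TrajectorySpans WHERE JobID = @job`,
		Params: map[string]any{"job": "some-job-id"},
	}
	iter := client.Single().Query(ctx, stmt)
	defer iter.Stop()
	row, err := iter.Next()
	if err != nil {
		log.Fatal(err)
	}
	var in, out, thoughts int64
	if err := row.Columns(&in, &out, &thoughts); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("tokens: input=%d output=%d thoughts=%d\n", in, out, thoughts)
}
```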
diff --git a/dashboard/app/templates/ai_job.html b/dashboard/app/templates/ai_job.html
index f8f2b82bd986..e1971ac3dcc9 100644
--- a/dashboard/app/templates/ai_job.html
+++ b/dashboard/app/templates/ai_job.html
@@ -86,6 +86,13 @@
 					{{if $span.Reply}}
 						Reply: <pre>
{{$span.Reply}}
 						</pre><br>
 					{{end}}
+					{{if $span.InputTokens}}
+						Tokens: <pre>
+							input: {{$span.InputTokens}}
+							output: {{$span.OutputTokens}}
+							thoughts: {{$span.OutputThoughtsTokens}}
+						</pre><br>
+					{{end}}
 					{{if $span.Thoughts}}
 						Thoughts: <pre>
{{$span.Thoughts}}
 						</pre><br>
 					{{end}}
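Note: `{{if $span.InputTokens}}` treats a zero int as false in html/template, so spans stored before this change (NULL read back as 0) render no Tokens block at all. A standalone sketch of that behavior; the field names mirror the patch, the template text is illustrative:

```go
// Demonstrates the template guard semantics: zero token counts (old rows,
// or responses without usage metadata) suppress the whole block.
package main

import (
	"html/template"
	"os"
)

const tpl = `{{if .InputTokens}}Tokens: input={{.InputTokens}} output={{.OutputTokens}} thoughts={{.OutputThoughtsTokens}}
{{else}}(no token data)
{{end}}`

type span struct {
	InputTokens, OutputTokens, OutputThoughtsTokens int
}

func main() {
	t := template.Must(template.New("span").Parse(tpl))
	t.Execute(os.Stdout, span{}) // old row: no block rendered
	t.Execute(os.Stdout, span{InputTokens: 2048, OutputTokens: 512, OutputThoughtsTokens: 128})
}
```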
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 405498800b77..ee83b541f321 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
 		state:   maps.Clone(inputs),
 		onEvent: onEvent,
 	}
+	defer c.close()
 
 	if s := ctx.Value(stubContextKey); s != nil {
 		c.stubContext = *s.(*stubContext)
@@ -143,10 +144,17 @@ var (
 	createClientOnce sync.Once
 	createClientErr  error
 	client           *genai.Client
-	modelList        = make(map[string]bool)
+	modelList        = make(map[string]*modelInfo)
 	stubContextKey   = contextKeyType(1)
 )
 
+type modelInfo struct {
+	Thinking         bool
+	MaxTemperature   float32
+	InputTokenLimit  int
+	OutputTokenLimit int
+}
+
 func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
 	req []*genai.Content) (*genai.GenerateContentResponse, error) {
 	const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
 				createClientErr = err
 				return
 			}
-			modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
+			if !slices.Contains(m.SupportedActions, "generateContent") ||
+				strings.Contains(m.Name, "-image") ||
+				strings.Contains(m.Name, "-audio") {
+				continue
+			}
+			modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
+				Thinking:         m.Thinking,
+				MaxTemperature:   m.MaxTemperature,
+				InputTokenLimit:  int(m.InputTokenLimit),
+				OutputTokenLimit: int(m.OutputTokenLimit),
+			}
 		}
 	})
 	if createClientErr != nil {
 		return nil, createClientErr
 	}
-	thinking, ok := modelList[model]
-	if !ok {
+	info := modelList[model]
+	if info == nil {
 		models := slices.Collect(maps.Keys(modelList))
 		slices.Sort(models)
 		return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
 	}
-	if thinking {
+	*cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
+	if info.Thinking {
 		// Don't alter the original object (that may affect request caching).
 		cfgCopy := *cfg
 		cfg = &cfgCopy
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 67b51dd7916a..2a03d82d0c9d 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -187,6 +187,9 @@ func TestWorkflow(t *testing.T) {
 				},
 			},
 		},
+		{
+			Text: "Some non-thoughts reply along with tool calls",
+		},
 		{
 			Text:    "I am thinking I need to call some tools",
 			Thought: true,
@@ -321,7 +324,7 @@ func TestToolMisbehavior(t *testing.T) {
 		&LLMAgent{
 			Name:        "smarty",
 			Model:       "model",
-			Temperature: 1,
+			Temperature: 0.5,
 			Reply:       "Reply",
 			Outputs: LLMOutputs[struct {
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index f60a19a8336b..692443ac3fcf 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -9,6 +9,8 @@ import (
 	"maps"
 	"net/http"
 	"reflect"
+	"regexp"
+	"strconv"
 	"strings"
 	"time"
 
@@ -33,7 +35,7 @@ type LLMAgent struct {
 	// Value that controls the degree of randomness in token selection.
 	// Lower temperatures are good for prompts that require a less open-ended or creative response,
 	// while higher temperatures can lead to more diverse or creative results.
-	// Must be assigned a float32 value in the range [0, 2].
+	// Must be assigned a number in the range [0, 2].
 	Temperature any
 	// If set, the agent will generate that many candidates and the outputs will be arrays
 	// instead of scalars.
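Note on the Temperature type change woven through llm_agent.go: untyped constants such as 0.5 default to float64 in Go, so requiring float32 forced an explicit cast at every call site (compare the flow_test.go change above). A self-contained sketch of the new coercion rules; normalizeTemperature is an illustrative helper, not a function from the patch:

```go
// Sketch of the relaxed Temperature contract: int and float64 are accepted,
// a bare float32 no longer is. Mirrors verify() and config() in llm_agent.go.
package main

import "fmt"

func normalizeTemperature(v any) (float32, error) {
	if i, ok := v.(int); ok { // Temperature: 1 stays convenient.
		v = float64(i)
	}
	f, ok := v.(float64) // Temperature: 0.5 is float64, not float32.
	if !ok || f < 0 || f > 2 {
		return 0, fmt.Errorf("Temperature must be a number in the range [0, 2]")
	}
	return float32(f), nil // genai.GenerateContentConfig still wants float32.
}

func main() {
	for _, v := range []any{1, 0.5, float32(1), "hot"} {
		f, err := normalizeTemperature(v)
		fmt.Printf("%T(%v) -> %v, err=%v\n", v, v, f, err)
	}
}
```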
@@ -190,8 +192,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
 		if err != nil {
 			return "", nil, ctx.finishSpan(reqSpan, err)
 		}
-		reply, thoughts, calls, respErr := a.parseResponse(resp)
-		reqSpan.Thoughts = thoughts
+		reply, calls, respErr := a.parseResponse(resp, reqSpan)
 		if err := ctx.finishSpan(reqSpan, respErr); err != nil {
 			return "", nil, err
 		}
@@ -243,7 +244,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
 	}
 	return &genai.GenerateContentConfig{
 		ResponseModalities: []string{"TEXT"},
-		Temperature:        genai.Ptr(a.Temperature.(float32)),
+		Temperature:        genai.Ptr(float32(a.Temperature.(float64))),
 		SystemInstruction:  genai.NewContentFromText(instruction, genai.RoleUser),
 		Tools:              tools,
 	}, instruction, toolMap
@@ -300,8 +301,8 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
 	return responses, outputs, nil
 }
 
-func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
-	reply, thoughts string, calls []*genai.FunctionCall, err error) {
+func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *trajectory.Span) (
+	reply string, calls []*genai.FunctionCall, err error) {
 	if len(resp.Candidates) == 0 || resp.Candidates[0] == nil {
 		err = fmt.Errorf("empty model response")
 		if resp.PromptFeedback != nil {
@@ -320,6 +321,13 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *traj
 		err = fmt.Errorf("unexpected reply fields (%+v)", *candidate)
 		return
 	}
+	if resp.UsageMetadata != nil {
+		// We add ToolUsePromptTokenCount just in case, but Gemini does not use/set it.
+		span.InputTokens = int(resp.UsageMetadata.PromptTokenCount) +
+			int(resp.UsageMetadata.ToolUsePromptTokenCount)
+		span.OutputTokens = int(resp.UsageMetadata.CandidatesTokenCount)
+		span.OutputThoughtsTokens = int(resp.UsageMetadata.ThoughtsTokenCount)
+	}
 	for _, part := range candidate.Content.Parts {
 		// We don't expect to receive these now.
 		if part.VideoMetadata != nil || part.InlineData != nil ||
@@ -331,7 +339,7 @@
 		if part.FunctionCall != nil {
 			calls = append(calls, part.FunctionCall)
 		} else if part.Thought {
-			thoughts += part.Text
+			span.Thoughts += part.Text
 		} else {
 			reply += part.Text
 		}
@@ -339,6 +347,11 @@
 	if strings.TrimSpace(reply) == "" {
 		reply = ""
 	}
+	// If there is any reply along with tool calls, append it to thoughts.
+	// Otherwise it won't show up in the trajectory anywhere.
+	if len(calls) != 0 && reply != "" {
+		span.Thoughts += "\n" + reply
+	}
 	return
 }
 
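Note: parseResponse above now records token usage directly on the request span. A minimal sketch of the same accounting against the google.golang.org/genai types, using only the UsageMetadata fields the patch itself touches; spanTokens is a stand-in for trajectory.Span:

```go
// ThoughtsTokenCount is reported separately from CandidatesTokenCount, so
// "output" and "thoughts" do not overlap. Missing metadata leaves all
// counts at zero, which toNullInt64 later stores as NULL.
package main

import (
	"fmt"

	"google.golang.org/genai"
)

// spanTokens mirrors the token fields added to trajectory.Span.
type spanTokens struct {
	InputTokens, OutputTokens, OutputThoughtsTokens int
}

func recordUsage(span *spanTokens, md *genai.GenerateContentResponseUsageMetadata) {
	if md == nil {
		return
	}
	// ToolUsePromptTokenCount is included just in case; Gemini does not set it.
	span.InputTokens = int(md.PromptTokenCount) + int(md.ToolUsePromptTokenCount)
	span.OutputTokens = int(md.CandidatesTokenCount)
	span.OutputThoughtsTokens = int(md.ThoughtsTokenCount)
}

func main() {
	var span spanTokens
	recordUsage(&span, &genai.GenerateContentResponseUsageMetadata{
		PromptTokenCount:     1200,
		CandidatesTokenCount: 300,
		ThoughtsTokenCount:   450,
	})
	fmt.Printf("%+v\n", span)
}
```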
@@ -355,14 +368,22 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
 			continue
 		}
 		if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
-			strings.Contains(apiErr.Message, "Quota exceeded for metric") &&
-			strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
-			return resp, &modelQuotaError{ctx.modelName(a.Model)}
+			strings.Contains(apiErr.Message, "Quota exceeded for metric") {
+			if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
+				sec, _ := strconv.Atoi(match[1])
+				time.Sleep(time.Duration(sec+1) * time.Second)
+				continue
+			}
+			if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
+				return resp, &modelQuotaError{ctx.modelName(a.Model)}
+			}
 		}
 		return resp, err
 	}
 }
 
+var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")
+
 func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig,
 	req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
 	type Cached struct {
@@ -389,10 +410,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
 	ctx.requireNotEmpty(a.Name, "Model", a.Model)
 	ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
 	if temp, ok := a.Temperature.(int); ok {
-		a.Temperature = float32(temp)
+		a.Temperature = float64(temp)
 	}
-	if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
-		ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
+	if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
+		ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
 	}
 	if a.Candidates < 0 || a.Candidates > 100 {
 		ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
index 370954d981fc..6c6478020e83 100644
--- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json
+++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
@@ -10,7 +10,7 @@
 		],
 		"role": "user"
 	},
-	"temperature": 1,
+	"temperature": 0.5,
 	"tools": [
 		{
 			"functionDeclarations": [
@@ -132,7 +132,7 @@
 		],
 		"role": "user"
 	},
-	"temperature": 1,
+	"temperature": 0.5,
 	"tools": [
 		{
 			"functionDeclarations": [
@@ -367,7 +367,7 @@
 		],
 		"role": "user"
 	},
-	"temperature": 1,
+	"temperature": 0.5,
 	"tools": [
 		{
 			"functionDeclarations": [
@@ -618,7 +618,7 @@
 		],
 		"role": "user"
 	},
-	"temperature": 1,
+	"temperature": 0.5,
 	"tools": [
 		{
 			"functionDeclarations": [
@@ -915,7 +915,7 @@
 		],
 		"role": "user"
 	},
-	"temperature": 1,
+	"temperature": 0.5,
 	"tools": [
 		{
 			"functionDeclarations": [
diff --git a/pkg/aflow/testdata/TestWorkflow.llm.json b/pkg/aflow/testdata/TestWorkflow.llm.json
index 681552571068..d7e490daaf93 100644
--- a/pkg/aflow/testdata/TestWorkflow.llm.json
+++ b/pkg/aflow/testdata/TestWorkflow.llm.json
@@ -320,6 +320,9 @@
 				"name": "tool2"
 			}
 		},
+		{
+			"text": "Some non-thoughts reply along with tool calls"
+		},
 		{
 			"text": "I am thinking I need to call some tools",
 			"thought": true
@@ -523,6 +526,9 @@
 				"name": "tool2"
 			}
 		},
+		{
+			"text": "Some non-thoughts reply along with tool calls"
+		},
 		{
 			"text": "I am thinking I need to call some tools",
 			"thought": true
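Note: the retry path added in generateContent above parses the server's retry hint. A standalone check of that regex, assuming quota errors are phrased like "Please retry in 7.328470008s." or "Please retry in 41s."; the pattern captures only the integer seconds, and the caller sleeps sec+1 to compensate for the truncated fraction:

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// Same pattern as the patch: digits terminated by either the decimal point
// or the trailing "s" of the duration.
var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")

func main() {
	for _, msg := range []string{
		"Quota exceeded for metric X. Please retry in 7.328470008s.",
		"Quota exceeded for metric X. Please retry in 41s.",
		"Quota exceeded for metric generate_requests_per_model_per_day.",
	} {
		if m := rePleaseRetry.FindStringSubmatch(msg); m != nil {
			sec, _ := strconv.Atoi(m[1])
			fmt.Printf("sleep %ds: %q\n", sec+1, msg)
		} else {
			fmt.Printf("no retry hint: %q\n", msg)
		}
	}
}
```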
diff --git a/pkg/aflow/testdata/TestWorkflow.trajectory.json b/pkg/aflow/testdata/TestWorkflow.trajectory.json
index 3a4859bbf154..ab3e4989983f 100644
--- a/pkg/aflow/testdata/TestWorkflow.trajectory.json
+++ b/pkg/aflow/testdata/TestWorkflow.trajectory.json
@@ -51,7 +51,7 @@
 		"Model": "model1",
 		"Started": "0001-01-01T00:00:05Z",
 		"Finished": "0001-01-01T00:00:06Z",
-		"Thoughts": "I am thinking I need to call some tools"
+		"Thoughts": "I am thinking I need to call some tools\nSome non-thoughts reply along with tool calls"
 	},
 	{
 		"Seq": 4,
diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go
index ad22b018e7cb..fc63235cfeb1 100644
--- a/pkg/aflow/trajectory/trajectory.go
+++ b/pkg/aflow/trajectory/trajectory.go
@@ -38,6 +38,12 @@ type Span struct {
 	// LLM invocation.
 	Thoughts string `json:",omitzero"`
+
+	// For details see:
+	// https://pkg.go.dev/google.golang.org/genai#GenerateContentResponseUsageMetadata
+	InputTokens          int `json:",omitzero"`
+	OutputTokens         int `json:",omitzero"`
+	OutputThoughtsTokens int `json:",omitzero"`
 }
 
 type SpanType string
@@ -89,6 +95,8 @@ func (span *Span) String() string {
 		}
 		fmt.Fprintf(sb, "reply:\n%v\n", span.Reply)
 	case SpanLLM:
+		fmt.Fprintf(sb, "tokens: input=%v output=%v thoughts=%v\n",
+			span.InputTokens, span.OutputTokens, span.OutputThoughtsTokens)
 		if span.Thoughts != "" {
 			fmt.Fprintf(sb, "thoughts:\n%v\n", span.Thoughts)
 		}
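Note: unlike the dashboard template, the tokens line in Span.String is printed unconditionally for LLM spans, so spans without usage metadata show explicit zeros in text traces. A sketch of the resulting trace fragment; the struct is a stand-in for trajectory.Span, the format string is the one the patch adds:

```go
package main

import (
	"fmt"
	"strings"
)

type span struct {
	InputTokens, OutputTokens, OutputThoughtsTokens int
	Thoughts                                        string
}

// String reproduces the LLM-span portion of trajectory.Span.String.
func (s *span) String() string {
	sb := new(strings.Builder)
	fmt.Fprintf(sb, "tokens: input=%v output=%v thoughts=%v\n",
		s.InputTokens, s.OutputTokens, s.OutputThoughtsTokens)
	if s.Thoughts != "" {
		fmt.Fprintf(sb, "thoughts:\n%v\n", s.Thoughts)
	}
	return sb.String()
}

func main() {
	s := &span{
		InputTokens:          1200,
		OutputTokens:         300,
		OutputThoughtsTokens: 450,
		Thoughts:             "I am thinking I need to call some tools",
	}
	fmt.Print(s.String())
}
```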