Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions dashboard/app/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,23 @@ type uiAIResult struct {
}

type uiAITrajectorySpan struct {
Started time.Time
Seq int64
Nesting int64
Type string
Name string
Model string
Duration time.Duration
Error string
Args string
Results string
Instruction string
Prompt string
Reply string
Thoughts string
Started time.Time
Seq int64
Nesting int64
Type string
Name string
Model string
Duration time.Duration
Error string
Args string
Results string
Instruction string
Prompt string
Reply string
Thoughts string
InputTokens int
OutputTokens int
OutputThoughtsTokens int
}

func handleAIJobsPage(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
Expand Down Expand Up @@ -233,20 +236,23 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
duration = span.Finished.Time.Sub(span.Started)
}
res = append(res, &uiAITrajectorySpan{
Started: span.Started,
Seq: span.Seq,
Nesting: span.Nesting,
Type: span.Type,
Name: span.Name,
Model: span.Model,
Duration: duration,
Error: nullString(span.Error),
Args: nullJSON(span.Args),
Results: nullJSON(span.Results),
Instruction: nullString(span.Instruction),
Prompt: nullString(span.Prompt),
Reply: nullString(span.Reply),
Thoughts: nullString(span.Thoughts),
Started: span.Started,
Seq: span.Seq,
Nesting: span.Nesting,
Type: span.Type,
Name: span.Name,
Model: span.Model,
Duration: duration,
Error: nullString(span.Error),
Args: nullJSON(span.Args),
Results: nullJSON(span.Results),
Instruction: nullString(span.Instruction),
Prompt: nullString(span.Prompt),
Reply: nullString(span.Reply),
Thoughts: nullString(span.Thoughts),
InputTokens: nullInt64(span.InputTokens),
OutputTokens: nullInt64(span.OutputTokens),
OutputThoughtsTokens: nullInt64(span.OutputThoughtsTokens),
})
}
return res
Expand Down Expand Up @@ -642,3 +648,10 @@ func nullJSON(v spanner.NullJSON) string {
}
return fmt.Sprint(v.Value)
}

// nullInt64 unwraps a spanner.NullInt64 into a plain int,
// treating a database NULL as zero (the inverse of toNullInt64).
func nullInt64(v spanner.NullInt64) int {
	if v.Valid {
		return int(v.Int64)
	}
	return 0
}
40 changes: 25 additions & 15 deletions dashboard/app/aidb/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,24 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
}
defer client.Close()
ent := TrajectorySpan{
JobID: jobID,
Seq: int64(span.Seq),
Nesting: int64(span.Nesting),
Type: string(span.Type),
Name: span.Name,
Model: span.Model,
Started: span.Started,
Finished: toNullTime(span.Finished),
Error: toNullString(span.Error),
Args: toNullJSON(span.Args),
Results: toNullJSON(span.Results),
Instruction: toNullString(span.Instruction),
Prompt: toNullString(span.Prompt),
Reply: toNullString(span.Reply),
Thoughts: toNullString(span.Thoughts),
JobID: jobID,
Seq: int64(span.Seq),
Nesting: int64(span.Nesting),
Type: string(span.Type),
Name: span.Name,
Model: span.Model,
Started: span.Started,
Finished: toNullTime(span.Finished),
Error: toNullString(span.Error),
Args: toNullJSON(span.Args),
Results: toNullJSON(span.Results),
Instruction: toNullString(span.Instruction),
Prompt: toNullString(span.Prompt),
Reply: toNullString(span.Reply),
Thoughts: toNullString(span.Thoughts),
InputTokens: toNullInt64(span.InputTokens),
OutputTokens: toNullInt64(span.OutputTokens),
OutputThoughtsTokens: toNullInt64(span.OutputThoughtsTokens),
}
mut, err := spanner.InsertOrUpdateStruct("TrajectorySpans", ent)
if err != nil {
Expand Down Expand Up @@ -290,3 +293,10 @@ func toNullString(v string) spanner.NullString {
}
return spanner.NullString{StringVal: v, Valid: true}
}

// toNullInt64 wraps an int in a spanner.NullInt64.
// A zero value is stored as NULL, matching nullInt64, which maps
// NULL back to 0 when reading.
func toNullInt64(v int) spanner.NullInt64 {
	if v != 0 {
		return spanner.NullInt64{Int64: int64(v), Valid: true}
	}
	return spanner.NullInt64{}
}
31 changes: 17 additions & 14 deletions dashboard/app/aidb/entities.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,21 @@ type Job struct {
type TrajectorySpan struct {
JobID string
// The following fields correspond one-to-one to trajectory.Span fields (add field comments there).
Seq int64
Nesting int64
Type string
Name string
Model string
Started time.Time
Finished spanner.NullTime
Error spanner.NullString
Args spanner.NullJSON
Results spanner.NullJSON
Instruction spanner.NullString
Prompt spanner.NullString
Reply spanner.NullString
Thoughts spanner.NullString
Seq int64
Nesting int64
Type string
Name string
Model string
Started time.Time
Finished spanner.NullTime
Error spanner.NullString
Args spanner.NullJSON
Results spanner.NullJSON
Instruction spanner.NullString
Prompt spanner.NullString
Reply spanner.NullString
Thoughts spanner.NullString
InputTokens spanner.NullInt64
OutputTokens spanner.NullInt64
OutputThoughtsTokens spanner.NullInt64
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE TrajectorySpans DROP COLUMN InputTokens;
ALTER TABLE TrajectorySpans DROP COLUMN OutputTokens;
ALTER TABLE TrajectorySpans DROP COLUMN OutputThoughtsTokens;
3 changes: 3 additions & 0 deletions dashboard/app/aidb/migrations/5_add_trajectory_tokens.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Add per-span LLM token accounting columns (nullable; NULL means the
-- span predates token tracking or reported no usage metadata).
ALTER TABLE TrajectorySpans ADD COLUMN InputTokens INT64;
ALTER TABLE TrajectorySpans ADD COLUMN OutputTokens INT64;
ALTER TABLE TrajectorySpans ADD COLUMN OutputThoughtsTokens INT64;
7 changes: 7 additions & 0 deletions dashboard/app/templates/ai_job.html
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@
{{if $span.Reply}}
<b>Reply:</b> <div id="ai_details_div"><pre>{{$span.Reply}}</pre></div><br>
{{end}}
{{if $span.InputTokens}}
<b>Tokens:</b> <div id="ai_details_div"><pre>
input: {{$span.InputTokens}}
output: {{$span.OutputTokens}}
thoughts: {{$span.OutputThoughtsTokens}}
</pre></div><br>
{{end}}
{{if $span.Thoughts}}
<b>Thoughts:</b> <div id="ai_details_div"><pre>{{$span.Thoughts}}</pre></div><br>
{{end}}
Expand Down
29 changes: 24 additions & 5 deletions pkg/aflow/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
state: maps.Clone(inputs),
onEvent: onEvent,
}

defer c.close()
if s := ctx.Value(stubContextKey); s != nil {
c.stubContext = *s.(*stubContext)
Expand Down Expand Up @@ -143,10 +144,17 @@ var (
createClientOnce sync.Once
createClientErr error
client *genai.Client
modelList = make(map[string]bool)
modelList = make(map[string]*modelInfo)
stubContextKey = contextKeyType(1)
)

// modelInfo caches per-model capabilities/limits obtained from the
// model listing, keyed by model name in modelList.
type modelInfo struct {
	Thinking             bool    // model emits "thinking"/thoughts output
	MaxTemperature       float32 // upper bound we clamp the requested temperature to
	InputTokenLimit      int     // maximum input tokens reported by the API (currently informational)
	OutputTokenLimit     int     // maximum output tokens reported by the API (currently informational)
}

func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
Expand All @@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
createClientErr = err
return
}
modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
if !slices.Contains(m.SupportedActions, "generateContent") ||
strings.Contains(m.Name, "-image") ||
strings.Contains(m.Name, "-audio") {
continue
}
modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
Thinking: m.Thinking,
MaxTemperature: m.MaxTemperature,
InputTokenLimit: int(m.InputTokenLimit),
OutputTokenLimit: int(m.OutputTokenLimit),
}
}
})
if createClientErr != nil {
return nil, createClientErr
}
thinking, ok := modelList[model]
if !ok {
info := modelList[model]
if info == nil {
models := slices.Collect(maps.Keys(modelList))
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
if thinking {
*cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
if info.Thinking {
// Don't alter the original object (that may affect request caching).
cfgCopy := *cfg
cfg = &cfgCopy
Expand Down
5 changes: 4 additions & 1 deletion pkg/aflow/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ func TestWorkflow(t *testing.T) {
},
},
},
{
Text: "Some non-thoughts reply along with tool calls",
},
{
Text: "I am thinking I need to call some tools",
Thought: true,
Expand Down Expand Up @@ -321,7 +324,7 @@ func TestToolMisbehavior(t *testing.T) {
&LLMAgent{
Name: "smarty",
Model: "model",
Temperature: 1,
Temperature: 0.5,
Reply: "Reply",

Outputs: LLMOutputs[struct {
Expand Down
47 changes: 34 additions & 13 deletions pkg/aflow/llm_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"maps"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
"time"

Expand All @@ -33,7 +35,7 @@ type LLMAgent struct {
// Value that controls the degree of randomness in token selection.
// Lower temperatures are good for prompts that require a less open-ended or creative response,
// while higher temperatures can lead to more diverse or creative results.
// Must be assigned a float32 value in the range [0, 2].
// Must be assigned a number in the range [0, 2].
Temperature any
// If set, the agent will generate that many candidates and the outputs will be arrays
// instead of scalars.
Expand Down Expand Up @@ -190,8 +192,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err != nil {
return "", nil, ctx.finishSpan(reqSpan, err)
}
reply, thoughts, calls, respErr := a.parseResponse(resp)
reqSpan.Thoughts = thoughts
reply, calls, respErr := a.parseResponse(resp, reqSpan)
if err := ctx.finishSpan(reqSpan, respErr); err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -243,7 +244,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
}
return &genai.GenerateContentConfig{
ResponseModalities: []string{"TEXT"},
Temperature: genai.Ptr(a.Temperature.(float32)),
Temperature: genai.Ptr(float32(a.Temperature.(float64))),
SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
Tools: tools,
}, instruction, toolMap
Expand Down Expand Up @@ -300,8 +301,8 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
return responses, outputs, nil
}

func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
reply, thoughts string, calls []*genai.FunctionCall, err error) {
func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *trajectory.Span) (
reply string, calls []*genai.FunctionCall, err error) {
if len(resp.Candidates) == 0 || resp.Candidates[0] == nil {
err = fmt.Errorf("empty model response")
if resp.PromptFeedback != nil {
Expand All @@ -320,6 +321,13 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
err = fmt.Errorf("unexpected reply fields (%+v)", *candidate)
return
}
if resp.UsageMetadata != nil {
// We add ToolUsePromptTokenCount just in case, but Gemini does not use/set it.
span.InputTokens = int(resp.UsageMetadata.PromptTokenCount) +
int(resp.UsageMetadata.ToolUsePromptTokenCount)
span.OutputTokens = int(resp.UsageMetadata.CandidatesTokenCount)
span.OutputThoughtsTokens = int(resp.UsageMetadata.ThoughtsTokenCount)
}
for _, part := range candidate.Content.Parts {
// We don't expect to receive these now.
if part.VideoMetadata != nil || part.InlineData != nil ||
Expand All @@ -331,14 +339,19 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
if part.FunctionCall != nil {
calls = append(calls, part.FunctionCall)
} else if part.Thought {
thoughts += part.Text
span.Thoughts += part.Text
} else {
reply += part.Text
}
}
if strings.TrimSpace(reply) == "" {
reply = ""
}
// If there is any reply along with tool calls, append it to thoughts.
// Otherwise it won't show up in the trajectory anywhere.
if len(calls) != 0 && reply != "" {
span.Thoughts += "\n" + reply
}
return
}

Expand All @@ -355,14 +368,22 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
continue
}
if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
strings.Contains(apiErr.Message, "Quota exceeded for metric") &&
strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
return resp, &modelQuotaError{ctx.modelName(a.Model)}
strings.Contains(apiErr.Message, "Quota exceeded for metric") {
if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
sec, _ := strconv.Atoi(match[1])
time.Sleep(time.Duration(sec+1) * time.Second)
continue
}
if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
return resp, &modelQuotaError{ctx.modelName(a.Model)}
}
}
return resp, err
}
}

var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")

func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig,
req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
type Cached struct {
Expand All @@ -389,10 +410,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
ctx.requireNotEmpty(a.Name, "Model", a.Model)
ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
a.Temperature = float32(temp)
a.Temperature = float64(temp)
}
if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
}
if a.Candidates < 0 || a.Candidates > 100 {
ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
Expand Down
Loading
Loading