Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/aflow/func_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package aflow

import (
"errors"
"fmt"

"google.golang.org/genai"
)
Expand All @@ -28,8 +28,8 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State,
// BadCallError creates an error that means that LLM made a bad tool call,
// the provided message will be returned to the LLM as an error,
// instead of failing the whole workflow.
func BadCallError(message string) error {
return &badCallError{errors.New(message)}
func BadCallError(message string, args ...any) error {
return &badCallError{fmt.Errorf(message, args...)}
}

type badCallError struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/aflow/llm_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
if err := ctx.startSpan(span); err != nil {
return nil, nil, err
}
toolErr := BadCallError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name))
toolErr := BadCallError("tool %q does not exist, please correct the name", call.Name)
tool := tools[call.Name]
if tool != nil {
span.Results, toolErr = tool.execute(ctx, call.Args)
Expand Down
10 changes: 5 additions & 5 deletions pkg/aflow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
continue
}
if tool {
return val, BadCallError(fmt.Sprintf("missing argument %q", name))
return val, BadCallError("missing argument %q", name)
} else {
return val, fmt.Errorf("%T: field %q is not present when converting map", val, name)
}
Expand All @@ -107,8 +107,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
iv := fValue.Convert(field.Type())
if fv := iv.Convert(fType); !fValue.Equal(fv) {
if tool {
return val, BadCallError(fmt.Sprintf("argument %v: float value truncated from %v to %v",
name, f, iv.Interface()))
return val, BadCallError("argument %v: float value truncated from %v to %v",
name, f, iv.Interface())
} else {
return val, fmt.Errorf("%T: field %v: float value truncated from %v to %v",
val, name, f, iv.Interface())
Expand All @@ -119,8 +119,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
field.Set(fValue)
} else {
if tool {
return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v",
name, f, field.Type().Name()))
return val, BadCallError("argument %q has wrong type: got %T, want %v",
name, f, field.Type().Name())
} else {
return val, fmt.Errorf("%T: field %q has wrong type: got %T, want %v",
val, name, f, field.Type().Name())
Expand Down
16 changes: 14 additions & 2 deletions syz-agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"maps"
_ "net/http/pprof"
"path/filepath"
"slices"
"sync"
"time"

Expand Down Expand Up @@ -51,6 +52,8 @@ type Config struct {
FixedRepository string `json:"repo"`
// Use this LLM model (for testing, if empty use workflow-default model).
Model string `json:"model"`
// Names of workflows to serve (all if not set, mainly for testing/local experimentation).
Workflows []string `json:"workflows"`
}

func main() {
Expand Down Expand Up @@ -175,21 +178,25 @@ func (s *Server) poll(ctx context.Context) (bool, error) {
CodeRevision: prog.GitRevision,
}
for _, flow := range aflow.Flows {
if s.modelOverQuota(flow) {
if len(s.cfg.Workflows) != 0 && !slices.Contains(s.cfg.Workflows, flow.Name) ||
s.modelOverQuota(flow) {
continue
}
req.Workflows = append(req.Workflows, dashapi.AIWorkflow{
Type: flow.Type,
Name: flow.Name,
})
}
log.Logf(0, "querying jobs for %v", req.Workflows)
resp, err := s.dash.AIJobPoll(req)
if err != nil {
return false, err
}
if resp.ID == "" {
return false, nil
}
log.Logf(0, "starting job %v %v", resp.Workflow, resp.ID)
defer log.Logf(0, "finished job %v %v", resp.Workflow, resp.ID)
doneReq := &dashapi.AIJobDoneReq{
ID: resp.ID,
}
Expand All @@ -208,10 +215,14 @@ func (s *Server) poll(ctx context.Context) (bool, error) {
// the dashboard at all. For the dashboard it will look like
// the server has crashed while executing the job, and it should
// eventually retry it on common grounds.
s.overQuotaModels[model] = time.Now()
now := time.Now()
s.overQuotaModels[model] = now
log.Logf(0, "model %v is over daily quota until %v",
model, aflow.QuotaResetTime(now))
return true, nil
}
}
log.Logf(0, "done executing job %v %v", resp.Workflow, resp.ID)
if err := s.dash.AIJobDone(doneReq); err != nil {
return false, err
}
Expand Down Expand Up @@ -261,6 +272,7 @@ func (s *Server) modelOverQuota(flow *aflow.Flow) bool {
func (s *Server) resetModelQuota() {
for model, when := range s.overQuotaModels {
if aflow.QuotaResetTime(when).After(time.Now()) {
log.Logf(0, "model %v daily quota is replenished", model)
delete(s.overQuotaModels, model)
}
}
Expand Down
Loading