9 changes: 9 additions & 0 deletions pkg/aflow/execute.go
@@ -105,6 +105,15 @@ func (err *modelQuotaError) Error() string {
return fmt.Sprintf("model %q is over daily quota", err.model)
}

func isTokenOverflowError(err error) bool {
var overflowErr *tokenOverflowError
return errors.As(err, &overflowErr)
}

type tokenOverflowError struct {
error
}

// QuotaResetTime returns the time when RPD quota will be reset
// for a quota overflow happened at time t.
func QuotaResetTime(t time.Time) time.Time {
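For readers unfamiliar with the embedded-error pattern used here: tokenOverflowError wraps the underlying API error, and isTokenOverflowError recognizes it anywhere in a wrapped chain via errors.As. A minimal, self-contained sketch of that behavior (the fmt.Errorf wrapping is illustrative only, not part of this change):

```go
package main

import (
	"errors"
	"fmt"
)

// Mirrors the sentinel added in execute.go: embedding error promotes
// Error(), so the wrapper adds a type without changing the message.
type tokenOverflowError struct {
	error
}

func isTokenOverflowError(err error) bool {
	var overflowErr *tokenOverflowError
	return errors.As(err, &overflowErr)
}

func main() {
	base := errors.New("the input token count exceeds the maximum")
	wrapped := fmt.Errorf("generate content: %w", &tokenOverflowError{base})

	fmt.Println(isTokenOverflowError(wrapped)) // true: errors.As walks the %w chain
	fmt.Println(isTokenOverflowError(base))    // false: never wrapped in the sentinel
	fmt.Println(wrapped)                       // "generate content: the input token count exceeds the maximum"
}
```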
52 changes: 42 additions & 10 deletions pkg/aflow/llm_agent.go
@@ -110,6 +110,11 @@ Note: if you already provided you final reply, you will need to provide it again
Or did you want to call some other tools, but did not actually do that?
`

const llmAnswerNow = `
Provide a best-effort answer to the original question with all of the information
you have so far without calling any more tools!
`

type llmOutputs struct {
tool Tool
provideOutputs func(*verifyContext, string, bool)
@@ -178,22 +183,43 @@ func (a *LLMAgent) executeOne(ctx *Context, candidate int) (string, map[string]a
func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool,
prompt string, candidate int) (string, map[string]any, error) {
var outputs map[string]any
answerNow := false
req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
for {
reqSpan := &trajectory.Span{
span := &trajectory.Span{
Type: trajectory.SpanLLM,
Name: a.Name,
Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(reqSpan); err != nil {
if err := ctx.startSpan(span); err != nil {
return "", nil, err
}
resp, err := a.generateContent(ctx, cfg, req, candidate)
if err != nil {
return "", nil, ctx.finishSpan(reqSpan, err)
resp, respErr := a.generateContent(ctx, cfg, req, candidate)
if respErr != nil {
span.Error = respErr.Error()
if err := ctx.finishSpan(span, nil); err != nil {
return "", nil, err
}
// Input overflows maximum number of tokens.
// If this is an LLMTool, we remove the last tool reply,
// and replace it with an order to answer right now.
if isTokenOverflowError(respErr) &&
a.Reply == llmToolReply &&
len(req) >= 3 &&
!answerNow {
answerNow = true
cfg.ToolConfig = &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeNone,
},
}
req[len(req)-1] = genai.NewContentFromText(llmAnswerNow, genai.RoleUser)
continue
}
return "", nil, respErr
}
reply, calls, respErr := a.parseResponse(resp, reqSpan)
if err := ctx.finishSpan(reqSpan, respErr); err != nil {
reply, calls, respErr := a.parseResponse(resp, span)
if err := ctx.finishSpan(span, respErr); err != nil {
return "", nil, err
}
req = append(req, resp.Candidates[0].Content)
@@ -361,13 +387,15 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
for try := 0; ; try++ {
resp, err := a.generateContentCached(ctx, cfg, req, candidate)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
if err == nil || !errors.As(err, &apiErr) {
return resp, err
}
if try < 100 && apiErr.Code == http.StatusServiceUnavailable {
time.Sleep(backoff)
backoff = min(backoff+time.Second, 10*time.Second)
continue
}
if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
if apiErr.Code == http.StatusTooManyRequests &&
strings.Contains(apiErr.Message, "Quota exceeded for metric") {
if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
sec, _ := strconv.Atoi(match[1])
@@ -378,6 +406,10 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
return resp, &modelQuotaError{ctx.modelName(a.Model)}
}
}
if apiErr.Code == http.StatusBadRequest &&
strings.Contains(apiErr.Message, "The input token count exceeds the maximum") {
return resp, &tokenOverflowError{err}
}
return resp, err
}
}
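Read together, the two llm_agent.go hunks implement a one-shot degradation: generateContent classifies failures (503s get a capped backoff retry, quota 429s become modelQuotaError, and the new 400 "input token count" case becomes tokenOverflowError), and chat reacts to the overflow by forbidding further tool calls and replacing the last tool reply with the llmAnswerNow instruction. The len(req) >= 3 guard ensures the history already holds the prompt, a model turn, and a tool reply, so the entry being overwritten is a tool reply rather than the original question. A hedged sketch of just the classification step, assuming it sits in pkg/aflow next to the existing error types and that the SDK import path is google.golang.org/genai:

```go
package aflow

import (
	"errors"
	"net/http"
	"strings"

	"google.golang.org/genai"
)

// classifyAPIError is a condensed view of the branching in generateContent.
// It is a sketch only: the real code also retries 503s with backoff and
// parses the "Please retry in Ns" hint out of quota-exceeded messages.
func classifyAPIError(model string, err error) error {
	var apiErr genai.APIError
	if err == nil || !errors.As(err, &apiErr) {
		return err
	}
	switch {
	case apiErr.Code == http.StatusTooManyRequests &&
		strings.Contains(apiErr.Message, "Quota exceeded for metric"):
		return &modelQuotaError{model}
	case apiErr.Code == http.StatusBadRequest &&
		strings.Contains(apiErr.Message, "The input token count exceeds the maximum"):
		return &tokenOverflowError{err}
	}
	return err
}
```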
20 changes: 18 additions & 2 deletions pkg/aflow/llm_tool_test.go
@@ -4,6 +4,8 @@
package aflow

import (
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
@@ -40,7 +42,7 @@ func TestLLMTool(t *testing.T) {
NewFuncTool("researcher-tool", func(ctx *Context, state inputs, args toolArgs) (struct{}, error) {
// State passed all the way from the workflow inputs.
assert.Equal(t, state.Input, 42)
assert.True(t, args.Something == "subtool input 1" || args.Something == "subtool input 2",
assert.True(t, strings.HasPrefix(args.Something, "subtool input"),
"args.Something=%q", args.Something)
return struct{}{}, nil
}, "researcher-tool description"),
@@ -84,13 +86,27 @@ func TestLLMTool(t *testing.T) {
},
&genai.Part{
FunctionCall: &genai.FunctionCall{
ID: "id2",
ID: "id3",
Name: "researcher-tool",
Args: map[string]any{
"Something": "subtool input 2",
},
},
},
// Now model input token overflow.
&genai.Part{
FunctionCall: &genai.FunctionCall{
ID: "id4",
Name: "researcher-tool",
Args: map[string]any{
"Something": "subtool input 3",
},
},
},
genai.APIError{
Code: http.StatusBadRequest,
Message: "The input token count exceeds the maximum number of tokens allowed 1048576.",
},
genai.NewPartFromText("Still nothing."),
// Main returns result.
genai.NewPartFromText("YES"),
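The fixture injects a genai.APIError between the third tool call and the final text replies, which is exactly what the 400 branch in generateContent looks for. A quick standalone check of that routing, assuming genai.APIError satisfies error with a value receiver (which the value-typed errors.As target in generateContent already relies on) and that the import path is google.golang.org/genai:

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
	"strings"

	"google.golang.org/genai"
)

func main() {
	// The error value injected by the test fixture, verbatim.
	injected := genai.APIError{
		Code:    http.StatusBadRequest,
		Message: "The input token count exceeds the maximum number of tokens allowed 1048576.",
	}

	// generateContent matches the error by value and routes on the message text,
	// turning it into a tokenOverflowError that chat then degrades on.
	var apiErr genai.APIError
	fmt.Println(errors.As(error(injected), &apiErr)) // true
	fmt.Println(apiErr.Code == http.StatusBadRequest &&
		strings.Contains(apiErr.Message, "The input token count exceeds the maximum")) // true
}
```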
7 changes: 5 additions & 2 deletions pkg/aflow/runner_test.go
@@ -9,6 +9,7 @@ import (
"flag"
"os"
"path/filepath"
"slices"
"testing"
"time"

@@ -33,7 +34,7 @@ func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result a
require.NoError(t, err)
type llmRequest struct {
Model string
Config *genai.GenerateContentConfig
Config genai.GenerateContentConfig
Request []*genai.Content
}
var requests []llmRequest
@@ -45,7 +46,9 @@ func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result a
},
generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
requests = append(requests, llmRequest{model, cfg, req})
// Copy config and req slices, so that future changes to these objects
// don't affect our stored requests.
requests = append(requests, llmRequest{model, *cfg, slices.Clone(req)})
require.NotEmpty(t, llmReplies, "unexpected LLM call")
reply := llmReplies[0]
llmReplies = llmReplies[1:]
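The runner_test.go change exists because chat now mutates cfg and req in place on the answer-now path (cfg.ToolConfig is replaced and the last element of req is overwritten), so a recorder that kept the pointer and the slice header would see its already-recorded requests change afterwards. A small illustration of that aliasing, with plain strings standing in for *genai.Content:

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	req := []string{"prompt", "oversized tool reply"}

	aliased := req              // shares the backing array with req
	cloned := slices.Clone(req) // independent copy of the elements

	// What chat does on token overflow: overwrite the last entry in place.
	req[len(req)-1] = "answer now, without calling any more tools"

	fmt.Println(aliased) // [prompt answer now, without calling any more tools]
	fmt.Println(cloned)  // [prompt oversized tool reply]
}
```

Note that slices.Clone is a shallow copy; that is enough here because the agent replaces slice elements rather than mutating the Content values they point to.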