
Commit 43e1df1

pkg/aflow: handle input token overflow for LLM tools
Handle LLM tool input token overflow by removing the last tool reply and replacing it with an order to answer right now. I've seen an LLM tool go too deep into research and end up overflowing the input tokens; it could have provided at least some answer instead.
1 parent b441fd8 commit 43e1df1
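The shape of the fix, stripped of the genai plumbing, is roughly the toy loop below. This is only an illustrative sketch: generate, the plain-string transcript, and the length threshold are invented stand-ins, and the real logic also checks that the agent is an LLM tool (a.Reply == llmToolReply) and disables function calling via the ToolConfig, as the pkg/aflow/llm_agent.go diff below shows.

package main

import (
	"errors"
	"fmt"
	"strings"
)

// errTokenOverflow stands in for the tokenOverflowError this commit adds.
var errTokenOverflow = errors.New("input token count exceeds the maximum")

// generate stands in for the model call: it rejects transcripts that are
// too long and otherwise returns either a tool call or a plain answer.
func generate(transcript []string, toolsAllowed bool) (string, error) {
	total := 0
	for _, msg := range transcript {
		total += len(msg)
	}
	if total > 150 {
		return "", errTokenOverflow
	}
	if toolsAllowed {
		return "call: researcher-tool", nil
	}
	return "best-effort answer", nil
}

func main() {
	transcript := []string{"question", "tool call", strings.Repeat("huge tool reply ", 20)}
	toolsAllowed := true
	answerNow := false
	for {
		reply, err := generate(transcript, toolsAllowed)
		if errors.Is(err, errTokenOverflow) && !answerNow {
			// Overflow: replace the last (oversized) tool reply with an order
			// to answer right now, and forbid any further tool calls.
			answerNow = true
			toolsAllowed = false
			transcript[len(transcript)-1] = "Provide a best-effort answer now, without calling any more tools!"
			continue
		}
		if err != nil {
			panic(err)
		}
		fmt.Println(reply) // "best-effort answer" instead of a hard failure
		return
	}
}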

6 files changed, +352 −25 lines changed


pkg/aflow/execute.go

Lines changed: 9 additions & 0 deletions

@@ -105,6 +105,15 @@ func (err *modelQuotaError) Error() string {
 	return fmt.Sprintf("model %q is over daily quota", err.model)
 }
 
+func isTokenOverflowError(err error) bool {
+	var overflowErr *tokenOverflowError
+	return errors.As(err, &overflowErr)
+}
+
+type tokenOverflowError struct {
+	error
+}
+
 // QuotaResetTime returns the time when RPD quota will be reset
 // for a quota overflow happened at time t.
 func QuotaResetTime(t time.Time) time.Time {
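Because tokenOverflowError embeds the underlying error and isTokenOverflowError uses errors.As, the check still succeeds if the error gets wrapped again further up the call chain. A minimal standalone illustration of that pattern (the fmt.Errorf wrapping layer here is hypothetical, added only for the demo):

package main

import (
	"errors"
	"fmt"
)

// Same shape as the type added in pkg/aflow/execute.go.
type tokenOverflowError struct {
	error
}

func isTokenOverflowError(err error) bool {
	var overflowErr *tokenOverflowError
	return errors.As(err, &overflowErr)
}

func main() {
	apiErr := errors.New("400: The input token count exceeds the maximum number of tokens allowed")
	// A hypothetical caller adds one more layer of wrapping with %w.
	err := fmt.Errorf("generate content: %w", &tokenOverflowError{apiErr})

	fmt.Println(isTokenOverflowError(err))                 // true: errors.As unwraps down to the overflow error
	fmt.Println(isTokenOverflowError(errors.New("other"))) // false
}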

pkg/aflow/llm_agent.go

Lines changed: 42 additions & 10 deletions

@@ -110,6 +110,11 @@ Note: if you already provided you final reply, you will need to provide it again
 Or did you want to call some other tools, but did not actually do that?
 `
 
+const llmAnswerNow = `
+Provide a best-effort answer to the original question with all of the information
+you have so far without calling any more tools!
+`
+
 type llmOutputs struct {
 	tool           Tool
 	provideOutputs func(*verifyContext, string, bool)
@@ -178,22 +183,43 @@ func (a *LLMAgent) executeOne(ctx *Context, candidate int) (string, map[string]a
 func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool,
 	prompt string, candidate int) (string, map[string]any, error) {
 	var outputs map[string]any
+	answerNow := false
 	req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
 	for {
-		reqSpan := &trajectory.Span{
+		span := &trajectory.Span{
 			Type:  trajectory.SpanLLM,
 			Name:  a.Name,
 			Model: ctx.modelName(a.Model),
 		}
-		if err := ctx.startSpan(reqSpan); err != nil {
+		if err := ctx.startSpan(span); err != nil {
 			return "", nil, err
 		}
-		resp, err := a.generateContent(ctx, cfg, req, candidate)
-		if err != nil {
-			return "", nil, ctx.finishSpan(reqSpan, err)
+		resp, respErr := a.generateContent(ctx, cfg, req, candidate)
+		if respErr != nil {
+			span.Error = respErr.Error()
+			if err := ctx.finishSpan(span, nil); err != nil {
+				return "", nil, err
+			}
+			// Input overflows maximum number of tokens.
+			// If this is an LLMTool, we remove the last tool reply,
+			// and replace it with an order to answer right now.
+			if isTokenOverflowError(respErr) &&
+				a.Reply == llmToolReply &&
+				len(req) >= 3 &&
+				!answerNow {
+				answerNow = true
+				cfg.ToolConfig = &genai.ToolConfig{
+					FunctionCallingConfig: &genai.FunctionCallingConfig{
+						Mode: genai.FunctionCallingConfigModeNone,
+					},
+				}
+				req[len(req)-1] = genai.NewContentFromText(llmAnswerNow, genai.RoleUser)
+				continue
+			}
+			return "", nil, respErr
 		}
-		reply, calls, respErr := a.parseResponse(resp, reqSpan)
-		if err := ctx.finishSpan(reqSpan, respErr); err != nil {
+		reply, calls, respErr := a.parseResponse(resp, span)
+		if err := ctx.finishSpan(span, respErr); err != nil {
 			return "", nil, err
 		}
 		req = append(req, resp.Candidates[0].Content)
@@ -361,13 +387,15 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
 	for try := 0; ; try++ {
 		resp, err := a.generateContentCached(ctx, cfg, req, candidate)
 		var apiErr genai.APIError
-		if err != nil && try < 100 && errors.As(err, &apiErr) &&
-			apiErr.Code == http.StatusServiceUnavailable {
+		if err == nil || !errors.As(err, &apiErr) {
+			return resp, err
+		}
+		if try < 100 && apiErr.Code == http.StatusServiceUnavailable {
 			time.Sleep(backoff)
 			backoff = min(backoff+time.Second, 10*time.Second)
 			continue
 		}
-		if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
+		if apiErr.Code == http.StatusTooManyRequests &&
			strings.Contains(apiErr.Message, "Quota exceeded for metric") {
			if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
				sec, _ := strconv.Atoi(match[1])
@@ -378,6 +406,10 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
				return resp, &modelQuotaError{ctx.modelName(a.Model)}
			}
		}
+		if apiErr.Code == http.StatusBadRequest &&
+			strings.Contains(apiErr.Message, "The input token count exceeds the maximum") {
+			return resp, &tokenOverflowError{err}
+		}
 		return resp, err
 	}
 }

pkg/aflow/llm_tool_test.go

Lines changed: 18 additions & 2 deletions

@@ -4,6 +4,8 @@
 package aflow
 
 import (
+	"net/http"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -40,7 +42,7 @@ func TestLLMTool(t *testing.T) {
 		NewFuncTool("researcher-tool", func(ctx *Context, state inputs, args toolArgs) (struct{}, error) {
 			// State passed all the way from the workflow inputs.
 			assert.Equal(t, state.Input, 42)
-			assert.True(t, args.Something == "subtool input 1" || args.Something == "subtool input 2",
+			assert.True(t, strings.HasPrefix(args.Something, "subtool input"),
 				"args.Something=%q", args.Something)
 			return struct{}{}, nil
 		}, "researcher-tool description"),
@@ -84,13 +86,27 @@ func TestLLMTool(t *testing.T) {
 		},
 		&genai.Part{
 			FunctionCall: &genai.FunctionCall{
-				ID:   "id2",
+				ID:   "id3",
 				Name: "researcher-tool",
 				Args: map[string]any{
					"Something": "subtool input 2",
				},
			},
		},
+		// Now model input token overflow.
+		&genai.Part{
+			FunctionCall: &genai.FunctionCall{
+				ID:   "id4",
+				Name: "researcher-tool",
+				Args: map[string]any{
+					"Something": "subtool input 3",
+				},
+			},
+		},
+		genai.APIError{
+			Code:    http.StatusBadRequest,
+			Message: "The input token count exceeds the maximum number of tokens allowed 1048576.",
+		},
 		genai.NewPartFromText("Still nothing."),
 		// Main returns result.
 		genai.NewPartFromText("YES"),

pkg/aflow/runner_test.go

Lines changed: 5 additions & 2 deletions

@@ -9,6 +9,7 @@ import (
 	"flag"
 	"os"
 	"path/filepath"
+	"slices"
 	"testing"
 	"time"
 
@@ -33,7 +34,7 @@ func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result a
 	require.NoError(t, err)
 	type llmRequest struct {
 		Model   string
-		Config  *genai.GenerateContentConfig
+		Config  genai.GenerateContentConfig
 		Request []*genai.Content
 	}
 	var requests []llmRequest
@@ -45,7 +46,9 @@ func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result a
 		},
 		generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
 			*genai.GenerateContentResponse, error) {
-			requests = append(requests, llmRequest{model, cfg, req})
+			// Copy config and req slices, so that future changes to these objects
+			// don't affect our stored requests.
+			requests = append(requests, llmRequest{model, *cfg, slices.Clone(req)})
 			require.NotEmpty(t, llmReplies, "unexpected LLM call")
 			reply := llmReplies[0]
 			llmReplies = llmReplies[1:]
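A note on why the test now clones: the caller keeps mutating cfg and appending to (and rewriting) req across chat turns, so storing the original pointer and slice header would let later turns silently rewrite already-recorded requests. A self-contained illustration of the slice-aliasing hazard (toy string data, not the actual genai types):

package main

import (
	"fmt"
	"slices"
)

func main() {
	req := make([]string, 0, 4)
	req = append(req, "prompt")

	stored := req                 // aliases the same backing array
	snapshot := slices.Clone(req) // independent copy

	// The caller keeps mutating req in place (capacity permits, so no realloc).
	req = append(req, "tool call")
	req[0] = "rewritten prompt"

	fmt.Println(stored)   // [rewritten prompt] — the recorded value changed underneath us
	fmt.Println(snapshot) // [prompt] — unaffected
}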
