Skip to content

Commit 5492569

Browse files
committed
pkg/aflow: add BadCallError
The error allows tools to communicate that an error is not an infrastructure error that must fail the whole workflow, but rather a bad tool invocation by an LLM (e.g. asking for a non-existent file contents). Previously in the codesearcher tool we used a separate Missing bool to communicate that. With the error everything just becomes cleaner and nicer. The errors also allows all other tools to communicate any errors to the LLM when the normal results cannot be provided and don't make sense.
1 parent ef14e4f commit 5492569

15 files changed

+195
-83
lines changed

pkg/aflow/func_tool.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
package aflow
55

66
import (
7+
"errors"
8+
79
"github.com/google/syzkaller/pkg/aflow/trajectory"
810
"google.golang.org/genai"
911
)
@@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State,
2426
}
2527
}
2628

29+
// BadCallError creates an error that means that LLM made a bad tool call,
30+
// the provided message will be returned to the LLM as an error,
31+
// instead of failing the whole workflow.
32+
func BadCallError(message string) error {
33+
return &badCallError{errors.New(message)}
34+
}
35+
36+
type badCallError struct {
37+
error
38+
}
39+
2740
type funcTool[State, Args, Results any] struct {
2841
Name string
2942
Description string

pkg/aflow/func_tool_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright 2026 syzkaller project authors. All rights reserved.
2+
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3+
4+
package aflow
5+
6+
import (
7+
"context"
8+
"errors"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/google/syzkaller/pkg/aflow/trajectory"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"google.golang.org/genai"
16+
)
17+
18+
func TestToolErrors(t *testing.T) {
19+
type flowOutputs struct {
20+
Reply string
21+
}
22+
type toolArgs struct {
23+
CallError bool `jsonschema:"call error"`
24+
}
25+
flows := make(map[string]*Flow)
26+
err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
27+
{
28+
Root: &LLMAgent{
29+
Name: "smarty",
30+
Model: "model",
31+
Reply: "Reply",
32+
Temperature: 0,
33+
Instruction: "Do something!",
34+
Prompt: "Prompt",
35+
Tools: []Tool{
36+
NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
37+
if args.CallError {
38+
return struct{}{}, BadCallError("you are wrong")
39+
}
40+
return struct{}{}, errors.New("hard error")
41+
}, "tool 1 description"),
42+
},
43+
},
44+
},
45+
})
46+
require.NoError(t, err)
47+
replySeq := 0
48+
stub := &stubContext{
49+
generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
50+
*genai.GenerateContentResponse, error) {
51+
52+
replySeq++
53+
switch replySeq {
54+
case 1:
55+
return &genai.GenerateContentResponse{
56+
Candidates: []*genai.Candidate{{
57+
Content: &genai.Content{
58+
Role: string(genai.RoleModel),
59+
Parts: []*genai.Part{
60+
{
61+
FunctionCall: &genai.FunctionCall{
62+
ID: "id0",
63+
Name: "faulty",
64+
Args: map[string]any{
65+
"CallError": true,
66+
},
67+
},
68+
},
69+
}}}}}, nil
70+
case 2:
71+
assert.Equal(t, req[2], &genai.Content{
72+
Role: string(genai.RoleUser),
73+
Parts: []*genai.Part{
74+
{
75+
FunctionResponse: &genai.FunctionResponse{
76+
ID: "id0",
77+
Name: "faulty",
78+
Response: map[string]any{
79+
"error": "you are wrong",
80+
},
81+
},
82+
}}})
83+
return &genai.GenerateContentResponse{
84+
Candidates: []*genai.Candidate{{
85+
Content: &genai.Content{
86+
Role: string(genai.RoleModel),
87+
Parts: []*genai.Part{
88+
{
89+
FunctionCall: &genai.FunctionCall{
90+
ID: "id0",
91+
Name: "faulty",
92+
Args: map[string]any{
93+
"CallError": false,
94+
},
95+
},
96+
},
97+
}}}}}, nil
98+
default:
99+
t.Fatal("unexpected LLM calls")
100+
return nil, nil
101+
}
102+
},
103+
}
104+
ctx := context.WithValue(context.Background(), stubContextKey, stub)
105+
workdir := t.TempDir()
106+
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
107+
require.NoError(t, err)
108+
onEvent := func(span *trajectory.Span) error { return nil }
109+
_, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
110+
require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]")
111+
}

pkg/aflow/llm_agent.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
246246
},
247247
})
248248
}
249+
appendError := func(message string) {
250+
appendPart(map[string]any{"error": message})
251+
}
249252
tool := tools[call.Name]
250253
if tool == nil {
251-
appendPart(map[string]any{
252-
"error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name),
253-
})
254+
appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name))
254255
continue
255256
}
256257
results, err := tool.execute(ctx, call.Args)
257258
if err != nil {
258-
if argsErr := new(toolArgsError); errors.As(err, &argsErr) {
259-
// LLM provided wrong arguments to the tool,
260-
// return the error back to the LLM instead of failing.
261-
appendPart(map[string]any{
262-
"error": err.Error(),
263-
})
259+
// LLM provided wrong arguments to the tool,
260+
// or the tool returned error message to the LLM.
261+
// Return the error back to the LLM instead of failing.
262+
if callErr := new(badCallError); errors.As(err, &callErr) {
263+
appendError(err.Error())
264264
continue
265265
}
266-
return nil, nil, err
266+
return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v",
267+
call.Name, err, call.Args)
267268
}
268269
appendPart(results)
269270
if a.Outputs != nil && tool == a.Outputs.tool {

pkg/aflow/schema.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
6363
f, ok := m[name]
6464
if !ok {
6565
if tool {
66-
return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
66+
return val, BadCallError(fmt.Sprintf("missing argument %q", name))
6767
} else {
6868
return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
6969
}
@@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
7979
field.Set(reflect.ValueOf(f))
8080
} else {
8181
if tool {
82-
return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
83-
name, f, field.Type().Name())}
82+
return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v",
83+
name, f, field.Type().Name()))
8484
} else {
8585
return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
8686
name, f, field.Type().Name())
@@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
9393
return val, nil
9494
}
9595

96-
type toolArgsError struct{ error }
97-
9896
// foreachField iterates over all public fields of the struct provided in data.
9997
func foreachField(data any) iter.Seq2[string, reflect.Value] {
10098
return func(yield func(string, reflect.Value) bool) {

pkg/aflow/tool/codesearcher/codesearcher.go

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ type dirIndexArgs struct {
6868
}
6969

7070
type dirIndexResult struct {
71-
Missing bool `jsonschema:"Set to true if the requested directory does not exist."`
7271
Subdirs []string `jsonschema:"List of direct subdirectories."`
7372
Files []string `jsonschema:"List of source files."`
7473
}
@@ -78,7 +77,6 @@ type readFileArgs struct {
7877
}
7978

8079
type readFileResult struct {
81-
Missing bool `jsonschema:"Set to true if the requested file does not exist."`
8280
Contents string `jsonschema:"File contents."`
8381
}
8482

@@ -87,7 +85,6 @@ type fileIndexArgs struct {
8785
}
8886

8987
type fileIndexResult struct {
90-
Missing bool `jsonschema:"Set to true if the file with the given name does not exist."`
9188
Entities []indexEntity `jsonschema:"List of entites defined in the file."`
9289
}
9390

@@ -103,7 +100,6 @@ type defCommentArgs struct {
103100
}
104101

105102
type defCommentResult struct {
106-
Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
107103
Kind string `jsonschema:"Kind of the entity: function, struct, variable."`
108104
Comment string `jsonschema:"Source comment for the entity."`
109105
}
@@ -117,7 +113,6 @@ type defSourceArgs struct {
117113

118114
// nolint: lll
119115
type defSourceResult struct {
120-
Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
121116
SourceFile string `jsonschema:"Source file path where the entity is defined."`
122117
SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."`
123118
}
@@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) {
159154
}
160155

161156
func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) {
162-
ok, subdirs, files, err := state.Index.DirIndex(args.Dir)
163-
res := dirIndexResult{
164-
Missing: !ok,
157+
subdirs, files, err := state.Index.DirIndex(args.Dir)
158+
return dirIndexResult{
165159
Subdirs: subdirs,
166160
Files: files,
167-
}
168-
return res, err
161+
}, err
169162
}
170163

171164
func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) {
172-
ok, contents, err := state.Index.ReadFile(args.File)
173-
res := readFileResult{
174-
Missing: !ok,
165+
contents, err := state.Index.ReadFile(args.File)
166+
return readFileResult{
175167
Contents: contents,
176-
}
177-
return res, err
168+
}, err
178169
}
179170

180171
func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) {
181-
ok, entities, err := state.Index.FileIndex(args.SourceFile)
182-
res := fileIndexResult{
183-
Missing: !ok,
184-
}
172+
entities, err := state.Index.FileIndex(args.SourceFile)
173+
res := fileIndexResult{}
185174
for _, ent := range entities {
186175
res.Entities = append(res.Entities, indexEntity{
187176
Kind: ent.Kind,
@@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil
193182

194183
func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) {
195184
info, err := state.Index.DefinitionComment(args.SourceFile, args.Name)
196-
if err != nil || info == nil {
197-
return defCommentResult{
198-
Missing: info == nil,
199-
}, err
185+
if err != nil {
186+
return defCommentResult{}, err
200187
}
201188
return defCommentResult{
202189
Kind: info.Kind,
@@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA
206193

207194
func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) {
208195
info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines)
209-
if err != nil || info == nil {
210-
return defSourceResult{
211-
Missing: info == nil,
212-
}, err
196+
if err != nil {
197+
return defSourceResult{}, err
213198
}
214199
return defSourceResult{
215200
SourceFile: info.File,

0 commit comments

Comments
 (0)