Skip to content

Commit e27b2c4

Browse files
committed
pkg/aflow: refactor tests
Add helper function that executes test workflows, compares results (trajectory, LLM requests) against golden files, and if requested updates these golden files.
1 parent d09a306 commit e27b2c4

10 files changed

+3456
-1355
lines changed

pkg/aflow/flow_test.go

Lines changed: 245 additions & 1254 deletions
Large diffs are not rendered by default.

pkg/aflow/func_tool_test.go

Lines changed: 36 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,9 @@
44
package aflow
55

66
import (
7-
"context"
87
"errors"
9-
"path/filepath"
108
"testing"
11-
"time"
129

13-
"github.com/google/syzkaller/pkg/aflow/trajectory"
14-
"github.com/stretchr/testify/assert"
15-
"github.com/stretchr/testify/require"
1610
"google.golang.org/genai"
1711
)
1812

@@ -23,90 +17,43 @@ func TestToolErrors(t *testing.T) {
2317
type toolArgs struct {
2418
CallError bool `jsonschema:"call error"`
2519
}
26-
flows := make(map[string]*Flow)
27-
err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
28-
{
29-
Root: &LLMAgent{
30-
Name: "smarty",
31-
Model: "model",
32-
Reply: "Reply",
33-
Temperature: 0,
34-
Instruction: "Do something!",
35-
Prompt: "Prompt",
36-
Tools: []Tool{
37-
NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
38-
if args.CallError {
39-
return struct{}{}, BadCallError("you are wrong")
40-
}
41-
return struct{}{}, errors.New("hard error")
42-
}, "tool 1 description"),
43-
},
20+
testFlow[struct{}, flowOutputs](t, nil,
21+
"tool faulty failed: error: hard error\nargs: map[CallError:false]",
22+
&LLMAgent{
23+
Name: "smarty",
24+
Model: "model",
25+
Reply: "Reply",
26+
Temperature: 0,
27+
Instruction: "Do something!",
28+
Prompt: "Prompt",
29+
Tools: []Tool{
30+
NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
31+
if args.CallError {
32+
return struct{}{}, BadCallError("you are wrong")
33+
}
34+
return struct{}{}, errors.New("hard error")
35+
}, "tool 1 description"),
4436
},
4537
},
46-
})
47-
require.NoError(t, err)
48-
replySeq := 0
49-
stub := &stubContext{
50-
// nolint:dupl
51-
generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
52-
*genai.GenerateContentResponse, error) {
53-
replySeq++
54-
switch replySeq {
55-
case 1:
56-
return &genai.GenerateContentResponse{
57-
Candidates: []*genai.Candidate{{
58-
Content: &genai.Content{
59-
Role: string(genai.RoleModel),
60-
Parts: []*genai.Part{
61-
{
62-
FunctionCall: &genai.FunctionCall{
63-
ID: "id0",
64-
Name: "faulty",
65-
Args: map[string]any{
66-
"CallError": true,
67-
},
68-
},
69-
},
70-
}}}}}, nil
71-
case 2:
72-
assert.Equal(t, req[2], &genai.Content{
73-
Role: string(genai.RoleUser),
74-
Parts: []*genai.Part{
75-
{
76-
FunctionResponse: &genai.FunctionResponse{
77-
ID: "id0",
78-
Name: "faulty",
79-
Response: map[string]any{
80-
"error": "you are wrong",
81-
},
82-
},
83-
}}})
84-
return &genai.GenerateContentResponse{
85-
Candidates: []*genai.Candidate{{
86-
Content: &genai.Content{
87-
Role: string(genai.RoleModel),
88-
Parts: []*genai.Part{
89-
{
90-
FunctionCall: &genai.FunctionCall{
91-
ID: "id0",
92-
Name: "faulty",
93-
Args: map[string]any{
94-
"CallError": false,
95-
},
96-
},
97-
},
98-
}}}}}, nil
99-
default:
100-
t.Fatal("unexpected LLM calls")
101-
return nil, nil
102-
}
38+
[]any{
39+
&genai.Part{
40+
FunctionCall: &genai.FunctionCall{
41+
ID: "id0",
42+
Name: "faulty",
43+
Args: map[string]any{
44+
"CallError": true,
45+
},
46+
},
47+
},
48+
&genai.Part{
49+
FunctionCall: &genai.FunctionCall{
50+
ID: "id0",
51+
Name: "faulty",
52+
Args: map[string]any{
53+
"CallError": false,
54+
},
55+
},
56+
},
10357
},
104-
}
105-
ctx := context.WithValue(context.Background(), stubContextKey, stub)
106-
workdir := t.TempDir()
107-
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
108-
require.NoError(t, err)
109-
onEvent := func(span *trajectory.Span) error { return nil }
110-
_, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
111-
require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]")
58+
)
11259
}

pkg/aflow/runner_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright 2026 syzkaller project authors. All rights reserved.
2+
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3+
4+
package aflow
5+
6+
import (
7+
"context"
8+
"encoding/json"
9+
"flag"
10+
"path/filepath"
11+
"testing"
12+
"time"
13+
14+
"github.com/google/syzkaller/pkg/aflow/trajectory"
15+
"github.com/google/syzkaller/pkg/osutil"
16+
"github.com/stretchr/testify/require"
17+
"google.golang.org/genai"
18+
)
19+
20+
// flagUpdate is the -update test flag (off by default). When set, testFlow rewrites
// the golden files under testdata/ from the actual execution before comparing against them.
var flagUpdate = flag.Bool("update", false, "update golden test files to match the actual execution")
21+
22+
// testFlow executes the provided test workflow by returning LLM replies from llmReplies.
23+
// The result can be either a map[string]any with Outputs fields, or an error,
24+
// if an error is expected as the result of the execution.
25+
// llmReplies objects can be either *genai.Part, []*genai.Part, or an error.
26+
// Requests sent to LLM are compared against "testdata/TestName.llm.json" file.
27+
// Resulting trajectory is compared against "testdata/TestName.trajectory.json" file.
28+
// If -update flag is provided, the golden testdata files are updated to match the actual execution.
29+
func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result any, root Action, llmReplies []any) {
30+
flows := make(map[string]*Flow)
31+
err := register[Inputs, Outputs]("test", "description", flows, []*Flow{{Root: root}})
32+
require.NoError(t, err)
33+
type llmRequest struct {
34+
Model string
35+
Config *genai.GenerateContentConfig
36+
Request []*genai.Content
37+
}
38+
var requests []llmRequest
39+
var stubTime time.Time
40+
stub := &stubContext{
41+
timeNow: func() time.Time {
42+
stubTime = stubTime.Add(time.Second)
43+
return stubTime
44+
},
45+
generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
46+
*genai.GenerateContentResponse, error) {
47+
requests = append(requests, llmRequest{model, cfg, req})
48+
require.NotEmpty(t, llmReplies, "unexpected LLM call")
49+
reply := llmReplies[0]
50+
llmReplies = llmReplies[1:]
51+
switch reply := reply.(type) {
52+
case error:
53+
return nil, reply
54+
case *genai.Part:
55+
return &genai.GenerateContentResponse{
56+
Candidates: []*genai.Candidate{{Content: &genai.Content{
57+
Role: string(genai.RoleUser),
58+
Parts: []*genai.Part{reply},
59+
}}}}, nil
60+
case []*genai.Part:
61+
return &genai.GenerateContentResponse{
62+
Candidates: []*genai.Candidate{{Content: &genai.Content{
63+
Role: string(genai.RoleUser),
64+
Parts: reply,
65+
}}}}, nil
66+
default:
67+
t.Fatalf("bad LLM reply type %T", reply)
68+
return nil, nil
69+
}
70+
},
71+
}
72+
var spans []*trajectory.Span
73+
onEvent := func(span *trajectory.Span) error {
74+
spans = append(spans, span)
75+
return nil
76+
}
77+
ctx := context.WithValue(context.Background(), stubContextKey, stub)
78+
workdir := t.TempDir()
79+
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
80+
require.NoError(t, err)
81+
if inputs == nil {
82+
inputs = map[string]any{}
83+
}
84+
got, err := flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
85+
switch result := result.(type) {
86+
case map[string]any:
87+
require.NoError(t, err)
88+
require.Equal(t, got, result)
89+
case string:
90+
require.Error(t, err)
91+
require.Equal(t, err.Error(), result)
92+
default:
93+
t.Fatalf("bad result type %T", result)
94+
}
95+
// We need to pass spans/requests via double marshal/unmarshal round-trip
96+
// b/c some values change during the first round-trip (int->float64, jsonschema).
97+
spansData, err := json.Marshal(spans)
98+
require.NoError(t, err)
99+
spans = nil
100+
require.NoError(t, json.Unmarshal(spansData, &spans))
101+
requestsData, err := json.Marshal(requests)
102+
require.NoError(t, err)
103+
requests = nil
104+
require.NoError(t, json.Unmarshal(requestsData, &requests))
105+
trajectoryFile := filepath.Join("testdata", t.Name()+".trajectory.json")
106+
requestsFile := filepath.Join("testdata", t.Name()+".llm.json")
107+
if *flagUpdate {
108+
require.NoError(t, osutil.WriteJSON(trajectoryFile, spans))
109+
require.NoError(t, osutil.WriteJSON(requestsFile, requests))
110+
}
111+
wantSpans, err := osutil.ReadJSON[[]*trajectory.Span](trajectoryFile)
112+
require.NoError(t, err)
113+
require.Equal(t, spans, wantSpans)
114+
wantRequests, err := osutil.ReadJSON[[]llmRequest](requestsFile)
115+
require.NoError(t, err)
116+
require.Equal(t, requests, wantRequests)
117+
require.Empty(t, llmReplies)
118+
}

0 commit comments

Comments
 (0)