Skip to content

Commit 2d5e421

Browse files
committed
pkg/aflow: allow to specify model per-flow
We may want to use a weaker model for some workflows. Allow to use different models for different workflows.
1 parent 1b03c2c commit 2d5e421

File tree

12 files changed

+90
-64
lines changed

12 files changed

+90
-64
lines changed

dashboard/app/ai.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,14 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
207207
}
208208

209209
func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) {
210-
if len(req.Workflows) == 0 || req.CodeRevision == "" || req.LLMModel == "" {
210+
if len(req.Workflows) == 0 || req.CodeRevision == "" {
211211
return nil, fmt.Errorf("invalid request")
212212
}
213+
for _, flow := range req.Workflows {
214+
if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" {
215+
return nil, fmt.Errorf("invalid request")
216+
}
217+
}
213218
if err := aidb.UpdateWorkflows(ctx, req.Workflows); err != nil {
214219
return nil, fmt.Errorf("failed UpdateWorkflows: %w", err)
215220
}

dashboard/app/ai_test.go

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,10 @@ func TestAIBugWorkflows(t *testing.T) {
6363

6464
_, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
6565
CodeRevision: prog.GitRevision,
66-
LLMModel: "smarty",
6766
Workflows: []dashapi.AIWorkflow{
68-
{Type: "patching", Name: "patching"},
69-
{Type: "patching", Name: "patching-foo"},
70-
{Type: "patching", Name: "patching-bar"},
67+
{Type: "patching", Name: "patching", LLMModel: "smarty"},
68+
{Type: "patching", Name: "patching-foo", LLMModel: "smarty"},
69+
{Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
7170
},
7271
})
7372
require.NoError(t, err)
@@ -77,25 +76,23 @@ func TestAIBugWorkflows(t *testing.T) {
7776

7877
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
7978
CodeRevision: prog.GitRevision,
80-
LLMModel: "smarty",
8179
Workflows: []dashapi.AIWorkflow{
82-
{Type: "patching", Name: "patching"},
83-
{Type: "patching", Name: "patching-bar"},
84-
{Type: "patching", Name: "patching-baz"},
85-
{Type: "assessment-kcsan", Name: "assessment-kcsan"},
80+
{Type: "patching", Name: "patching", LLMModel: "smarty"},
81+
{Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
82+
{Type: "patching", Name: "patching-baz", LLMModel: "smarty"},
83+
{Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
8684
},
8785
})
8886
require.NoError(t, err)
8987

9088
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
9189
CodeRevision: prog.GitRevision,
92-
LLMModel: "smarty",
9390
Workflows: []dashapi.AIWorkflow{
94-
{Type: "patching", Name: "patching"},
95-
{Type: "patching", Name: "patching-bar"},
96-
{Type: "patching", Name: "patching-qux"},
97-
{Type: "assessment-kcsan", Name: "assessment-kcsan"},
98-
{Type: "assessment-kcsan", Name: "assessment-kcsan-foo"},
91+
{Type: "patching", Name: "patching", LLMModel: "smarty"},
92+
{Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
93+
{Type: "patching", Name: "patching-qux", LLMModel: "smarty"},
94+
{Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
95+
{Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"},
9996
},
10097
})
10198
require.NoError(t, err)
@@ -117,9 +114,8 @@ func TestAIJob(t *testing.T) {
117114

118115
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
119116
CodeRevision: prog.GitRevision,
120-
LLMModel: "smarty",
121117
Workflows: []dashapi.AIWorkflow{
122-
{Type: "assessment-kcsan", Name: "assessment-kcsan"},
118+
{Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
123119
},
124120
})
125121
require.NoError(t, err)
@@ -137,9 +133,8 @@ func TestAIJob(t *testing.T) {
137133

138134
resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
139135
CodeRevision: prog.GitRevision,
140-
LLMModel: "smarty",
141136
Workflows: []dashapi.AIWorkflow{
142-
{Type: "assessment-kcsan", Name: "assessment-kcsan"},
137+
{Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
143138
},
144139
})
145140
require.NoError(t, err2)
@@ -214,9 +209,8 @@ func TestAIAssessmentKCSAN(t *testing.T) {
214209

215210
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
216211
CodeRevision: prog.GitRevision,
217-
LLMModel: "smarty",
218212
Workflows: []dashapi.AIWorkflow{
219-
{Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)},
213+
{Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"},
220214
},
221215
})
222216
require.NoError(t, err)

dashboard/app/aidb/crud.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,13 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
128128
}
129129
job = jobs[0]
130130
}
131-
job.Started = spanner.NullTime{
132-
Time: TimeNow(ctx),
133-
Valid: true,
131+
job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true}
132+
for _, flow := range req.Workflows {
133+
if job.Workflow == flow.Name {
134+
job.LLMModel = flow.LLMModel
135+
break
136+
}
134137
}
135-
job.LLMModel = req.LLMModel
136138
job.CodeRevision = req.CodeRevision
137139
mut, err := spanner.InsertOrUpdateStruct("Jobs", job)
138140
if err != nil {

dashboard/dashapi/ai.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ import (
99
)
1010

1111
type AIJobPollReq struct {
12-
LLMModel string // LLM model that will be used to execute jobs
1312
CodeRevision string // git commit of the syz-agent server
1413
Workflows []AIWorkflow
1514
}
1615

1716
type AIWorkflow struct {
18-
Type ai.WorkflowType
19-
Name string
17+
Type ai.WorkflowType
18+
Name string
19+
LLMModel string // LLM model that will be used to execute this workflow
2020
}
2121

2222
type AIJobPollResp struct {

pkg/aflow/execute.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ import (
1818
"google.golang.org/genai"
1919
)
2020

21-
// https://ai.google.dev/gemini-api/docs/models
22-
const DefaultModel = "gemini-3-pro-preview"
23-
21+
// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
22+
// The model argument sets Gemini model name to execute the workflow.
23+
// The workdir argument should point to a dir owned by aflow to store private data,
24+
// it can be shared across parallel executions in the same process, and preferably
25+
// preserved across process restarts for caching purposes.
2426
func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any,
2527
cache *Cache, onEvent onEvent) (map[string]any, error) {
2628
if err := flow.checkInputs(inputs); err != nil {

pkg/aflow/flow.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import (
2222
// Actions are nodes of the graph, and they consume/produce some named values
2323
// (input/output fields, and intermediate values consumed by other actions).
2424
type Flow struct {
25-
Name string // Empty for the main workflow for the workflow type.
26-
Root Action
25+
Name string // Empty for the main workflow for the workflow type.
26+
Model string // The default Gemini model name to execute this workflow.
27+
Root Action
2728

2829
*FlowType
2930
}
@@ -35,6 +36,12 @@ type FlowType struct {
3536
extractOutputs func(map[string]any) map[string]any
3637
}
3738

39+
// See https://ai.google.dev/gemini-api/docs/models
40+
const (
41+
BestExpensiveModel = "gemini-3-pro-preview"
42+
GoodBalancedModel = "gemini-3-flash-preview"
43+
)
44+
3845
var Flows = make(map[string]*Flow)
3946

4047
// Register a workflow type (characterized by Inputs and Outputs),
@@ -88,6 +95,7 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
8895
actions: make(map[string]bool),
8996
state: make(map[string]*varState),
9097
}
98+
ctx.requireNotEmpty(flow.Name, "Model", flow.Model)
9199
provideOutputs[Inputs](ctx, "flow inputs")
92100
flow.Root.verify(ctx)
93101
requireInputs[Outputs](ctx, "flow outputs")

pkg/aflow/flow/assessment/kcsan.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func init() {
2323
ai.WorkflowAssessmentKCSAN,
2424
"assess if a KCSAN report is about a benign race that only needs annotations or not",
2525
&aflow.Flow{
26+
Model: aflow.GoodBalancedModel,
2627
Root: &aflow.Pipeline{
2728
Actions: []aflow.Action{
2829
kernel.Checkout,

pkg/aflow/flow/assessment/moderation.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ func init() {
3333
ai.WorkflowModeration,
3434
"assess if a bug report is consistent and actionable or not",
3535
&aflow.Flow{
36+
Model: aflow.GoodBalancedModel,
3637
Root: &aflow.Pipeline{
3738
Actions: []aflow.Action{
3839
aflow.NewFuncAction("extract-crash-type", extractCrashType),

pkg/aflow/flow/patching/patching.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func init() {
4343
ai.WorkflowPatching,
4444
"generate a kernel patch fixing a provided bug reproducer",
4545
&aflow.Flow{
46+
Model: aflow.BestExpensiveModel,
4647
Root: &aflow.Pipeline{
4748
Actions: []aflow.Action{
4849
baseCommitPicker,

pkg/aflow/flow_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ func TestWorkflow(t *testing.T) {
7676
flows := make(map[string]*Flow)
7777
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
7878
{
79-
Name: "flow",
79+
Name: "flow",
80+
Model: "model",
8081
Root: NewPipeline(
8182
NewFuncAction("func-action",
8283
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -530,6 +531,7 @@ func TestNoInputs(t *testing.T) {
530531
flows := make(map[string]*Flow)
531532
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
532533
{
534+
Model: "model",
533535
Root: NewFuncAction("func-action",
534536
func(ctx *Context, args flowInputs) (flowOutputs, error) {
535537
return flowOutputs{}, nil

0 commit comments

Comments
 (0)