|
4 | 4 | package aflow |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "errors" |
7 | 8 | "fmt" |
8 | 9 | "maps" |
| 10 | + "net/http" |
9 | 11 | "reflect" |
| 12 | + "time" |
10 | 13 |
|
11 | 14 | "github.com/google/syzkaller/pkg/aflow/trajectory" |
12 | 15 | "google.golang.org/genai" |
@@ -155,7 +158,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma |
155 | 158 | if err := ctx.startSpan(reqSpan); err != nil { |
156 | 159 | return "", nil, err |
157 | 160 | } |
158 | | - resp, err := ctx.generateContent(cfg, req) |
| 161 | + resp, err := a.generateContent(ctx, cfg, req) |
159 | 162 | if err != nil { |
160 | 163 | return "", nil, ctx.finishSpan(reqSpan, err) |
161 | 164 | } |
@@ -271,6 +274,22 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) ( |
271 | 274 | return |
272 | 275 | } |
273 | 276 |
|
| 277 | +func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, |
| 278 | + req []*genai.Content) (*genai.GenerateContentResponse, error) { |
| 279 | + backoff := time.Second |
| 280 | + for try := 0; ; try++ { |
| 281 | + resp, err := ctx.generateContent(cfg, req) |
| 282 | + var apiErr genai.APIError |
| 283 | + if err != nil && try < 100 && errors.As(err, &apiErr) && |
| 284 | + apiErr.Code == http.StatusServiceUnavailable { |
| 285 | + time.Sleep(backoff) |
| 286 | + backoff = min(backoff+time.Second, 10*time.Second) |
| 287 | + continue |
| 288 | + } |
| 289 | + return resp, err |
| 290 | + } |
| 291 | +} |
| 292 | + |
274 | 293 | func (a *LLMAgent) verify(vctx *verifyContext) { |
275 | 294 | vctx.requireNotEmpty(a.Name, "Name", a.Name) |
276 | 295 | vctx.requireNotEmpty(a.Name, "Reply", a.Reply) |
|
0 commit comments