Skip to content

Commit 9d656e8

Browse files
committed
Query timeout for LLM queries to be able to abort long queries automatically
Closes #443
1 parent be241d3 commit 9d656e8

File tree

5 files changed

+67
-6
lines changed

5 files changed

+67
-6
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ type Evaluate struct {
6060
// ProviderUrls holds all custom inference endpoint urls for the providers.
6161
ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","`
6262
// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
63-
QueryAttempts uint `long:"attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." default:"3"`
63+
QueryAttempts uint `long:"query-attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." default:"3"`
64+
// QueryTimeout holds the timeout for model requests.
65+
QueryTimeout uint `long:"query-timeout" description:"Timeout of a model query in seconds. ("0" to disable)" default:"1200"`
6466

6567
// Repositories determines which repository should be used for the evaluation, or empty if all repositories should be used.
6668
Repositories []string `long:"repository" description:"Evaluate with this repository. By default all repositories are used."`
@@ -170,6 +172,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
170172
command.logger.Panicf("number of configured query attempts must be greater than zero")
171173
}
172174
evaluationContext.QueryAttempts = command.QueryAttempts
175+
evaluationContext.QueryTimeout = command.QueryTimeout
173176

174177
if command.ExecutionTimeout == 0 {
175178
command.logger.Panicf("execution timeout for compilation and tests must be greater than zero")

evaluate/evaluate.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type Context struct {
3030
ProviderForModel map[evalmodel.Model]provider.Provider
3131
// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
3232
QueryAttempts uint
33+
// QueryTimeout holds the timeout for model queries in seconds.
34+
QueryTimeout uint
3335

3436
// RepositoryPaths determines which relative repository paths should be used for the evaluation, or empty if all repositories should be used.
3537
RepositoryPaths []string
@@ -130,8 +132,9 @@ func Evaluate(ctx *Context) {
130132
modelSucceededBasicChecksOfLanguage[model] = map[evallanguage.Language]bool{}
131133
}
132134

133-
if r, ok := model.(evalmodel.SetQueryAttempts); ok {
135+
if r, ok := model.(evalmodel.SetQueryHandling); ok {
134136
r.SetQueryAttempts(ctx.QueryAttempts)
137+
r.SetQueryTimeout(ctx.QueryTimeout)
135138
}
136139

137140
for _, taskIdentifier := range temporaryRepository.Configuration().Tasks {

model/llm/llm.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type Model struct {
3535
attributes map[string]string
3636
// queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
3737
queryAttempts uint
38+
// queryTimeout holds the timeout for model requests in seconds.
39+
queryTimeout uint
3840

3941
// metaInformation holds a model meta information.
4042
metaInformation *model.MetaInformation
@@ -47,6 +49,7 @@ func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *
4749
provider: provider,
4850

4951
queryAttempts: 1,
52+
queryTimeout: 0,
5053
}
5154
llmModel.modelID, llmModel.attributes = model.ParseModelID(modelIDWithAttributes)
5255

@@ -61,6 +64,7 @@ func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string
6164
modelID: modelIdentifier,
6265

6366
queryAttempts: 1,
67+
queryTimeout: 0,
6468

6569
metaInformation: metaInformation,
6670
}
@@ -333,10 +337,19 @@ func (m *Model) query(logger *log.Logger, request string) (queryResult *provider
333337
if err := retry.Do(
334338
func() error {
335339
logger.Info("querying model", "model", m.ID(), "query-id", id, "prompt", string(bytesutil.PrefixLines([]byte(request), []byte("\t"))))
340+
ctx := context.Background()
341+
if m.queryTimeout > 0 {
342+
c, cancel := context.WithTimeoutCause(ctx, time.Second*time.Duration(m.queryTimeout), pkgerrors.Errorf("request query timeout (%d seconds)", m.queryTimeout))
343+
defer cancel()
344+
ctx = c
345+
}
346+
336347
start := time.Now()
337-
queryResult, err = m.provider.Query(context.Background(), m, request)
348+
queryResult, err = m.provider.Query(ctx, m, request)
338349
if err != nil {
339350
return err
351+
} else if ctx.Err() != nil {
352+
return context.Cause(ctx)
340353
}
341354
duration = time.Since(start)
342355
totalCosts := float64(-1)
@@ -520,9 +533,14 @@ func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute strin
520533
return assessment, nil
521534
}
522535

523-
var _ model.SetQueryAttempts = (*Model)(nil)
536+
var _ model.SetQueryHandling = (*Model)(nil)
524537

525538
// SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task.
526539
func (m *Model) SetQueryAttempts(queryAttempts uint) {
527540
m.queryAttempts = queryAttempts
528541
}
542+
543+
// SetQueryTimeout sets the timeout for model requests in seconds.
544+
func (m *Model) SetQueryTimeout(timeout uint) {
545+
m.queryTimeout = timeout
546+
}

model/llm/llm_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ func TestModelQuery(t *testing.T) {
173173
SetupMock func(mockedProvider *providertesting.MockQuery)
174174

175175
QueryAttempts uint
176+
QueryTimeout uint
176177
Request string
177178

178179
ExpectedResponse *provider.QueryResult
@@ -196,6 +197,7 @@ func TestModelQuery(t *testing.T) {
196197
}
197198
llm := NewModel(mock, "some-model")
198199
llm.SetQueryAttempts(tc.QueryAttempts)
200+
llm.SetQueryTimeout(tc.QueryTimeout)
199201

200202
queryResult, actualError := llm.query(logger, tc.Request)
201203

@@ -277,6 +279,39 @@ func TestModelQuery(t *testing.T) {
277279

278280
ValidateLogs: assertAllIDsMatch,
279281
})
282+
283+
validate(t, &testCase{
284+
Name: "Timeout",
285+
SetupMock: func(mockedProvider *providertesting.MockQuery) {
286+
queryResult := &provider.QueryResult{
287+
Message: "test response",
288+
}
289+
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
290+
},
291+
QueryAttempts: 1,
292+
QueryTimeout: 1,
293+
Request: "test request",
294+
ExpectedError: "request query timeout",
295+
})
296+
297+
validate(t, &testCase{
298+
Name: "Multiple Timeouts",
299+
SetupMock: func(mockedProvider *providertesting.MockQuery) {
300+
queryResult := &provider.QueryResult{
301+
Message: "test response",
302+
}
303+
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
304+
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
305+
},
306+
QueryAttempts: 2,
307+
QueryTimeout: 1,
308+
Request: "test request",
309+
ExpectedError: "request query timeout",
310+
311+
ValidateLogs: func(t *testing.T, logs string) {
312+
assert.Equal(t, 2, strings.Count(logs, "querying model"))
313+
},
314+
})
280315
}
281316

282317
func TestModelRepairSourceCodeFile(t *testing.T) {

model/model.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ type Context struct {
8282
Logger *log.Logger
8383
}
8484

85-
// SetQueryAttempts defines a model that can set the number of query attempts when a model request errors in the process of solving a task.
86-
type SetQueryAttempts interface {
85+
// SetQueryHandling defines a model that can configure how API queries are handled.
86+
type SetQueryHandling interface {
8787
// SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task.
8888
SetQueryAttempts(attempts uint)
89+
// SetQueryTimeout sets the timeout for model requests in seconds.
90+
SetQueryTimeout(timeout uint)
8991
}

0 commit comments

Comments
 (0)