Skip to content

Commit f6e2803

Browse files
committed
refactor, Make logger available within the provider query
1 parent 8103a0d commit f6e2803

File tree

8 files changed

+40
-38
lines changed

8 files changed

+40
-38
lines changed

evaluate/evaluate_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ func TestEvaluate(t *testing.T) {
181181
},
182182
}
183183
// Set up mocks, when test is running.
184-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult1, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
184+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult1, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
185185

186186
queryResult2 := &provider.QueryResult{
187187
Message: "",
@@ -192,7 +192,7 @@ func TestEvaluate(t *testing.T) {
192192
},
193193
}
194194
// Set up mocks, when test is running.
195-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult2, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
195+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult2, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
196196
},
197197
After: func(t *testing.T, logger *log.Logger, resultPath string) {
198198
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
@@ -296,7 +296,7 @@ func TestEvaluate(t *testing.T) {
296296

297297
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
298298
// Set up mocks, when test is running.
299-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel)
299+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel)
300300
},
301301
After: func(t *testing.T, logger *log.Logger, resultPath string) {
302302
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
@@ -380,10 +380,10 @@ func TestEvaluate(t *testing.T) {
380380
Message: "model-response",
381381
}
382382
// Set up mocks, when test is running.
383-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
384-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
385-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
386-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
383+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
384+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
385+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
386+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
387387
},
388388
After: func(t *testing.T, logger *log.Logger, resultPath string) {
389389
mockedQuery.AssertNumberOfCalls(t, "Query", 4)
@@ -486,7 +486,7 @@ func TestEvaluate(t *testing.T) {
486486
},
487487
}
488488
// Set up mocks, when test is running.
489-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
489+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
490490
},
491491
After: func(t *testing.T, logger *log.Logger, resultPath string) {
492492
mockedQuery.AssertNumberOfCalls(t, "Query", 2)

model/llm/llm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ func (m *Model) query(logger *log.Logger, request string) (queryResult *provider
345345
}
346346

347347
start := time.Now()
348-
queryResult, err = m.provider.Query(ctx, m, request)
348+
queryResult, err = m.provider.Query(ctx, logger, m, request)
349349
if err != nil {
350350
return err
351351
} else if ctx.Err() != nil {

model/llm/llm_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func TestModelGenerateTestsForFile(t *testing.T) {
120120
` + "```" + `
121121
`),
122122
}
123-
mockedProvider.On("Query", mock.Anything, mock.Anything, promptMessage).Return(queryResult, nil)
123+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, promptMessage).Return(queryResult, nil)
124124
},
125125

126126
Language: &golang.Language{},
@@ -151,7 +151,7 @@ func TestModelGenerateTestsForFile(t *testing.T) {
151151
TotalCost: 0.123456789,
152152
},
153153
}
154-
mockedProvider.On("Query", mock.Anything, mock.Anything, promptMessage).Return(queryResult, nil)
154+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, promptMessage).Return(queryResult, nil)
155155
},
156156

157157
Language: &golang.Language{},
@@ -242,7 +242,7 @@ func TestModelQuery(t *testing.T) {
242242
queryResult := &provider.QueryResult{
243243
Message: "test response",
244244
}
245-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil)
245+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(queryResult, nil)
246246
},
247247
QueryAttempts: 1,
248248
Request: "test request",
@@ -256,7 +256,7 @@ func TestModelQuery(t *testing.T) {
256256
validate(t, &testCase{
257257
Name: "Failed query no retry",
258258
SetupMock: func(mockedProvider *providertesting.MockQuery) {
259-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(nil, assert.AnError)
259+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(nil, assert.AnError)
260260
},
261261
QueryAttempts: 1,
262262
Request: "test request",
@@ -266,8 +266,8 @@ func TestModelQuery(t *testing.T) {
266266
validate(t, &testCase{
267267
Name: "Failed query with retry",
268268
SetupMock: func(mockedProvider *providertesting.MockQuery) {
269-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(nil, assert.AnError).Once()
270-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(&provider.QueryResult{
269+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(nil, assert.AnError).Once()
270+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(&provider.QueryResult{
271271
Message: "test response",
272272
}, nil).Once()
273273
},
@@ -286,7 +286,7 @@ func TestModelQuery(t *testing.T) {
286286
queryResult := &provider.QueryResult{
287287
Message: "test response",
288288
}
289-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
289+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
290290
},
291291
QueryAttempts: 1,
292292
QueryTimeout: 1,
@@ -300,8 +300,8 @@ func TestModelQuery(t *testing.T) {
300300
queryResult := &provider.QueryResult{
301301
Message: "test response",
302302
}
303-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
304-
mockedProvider.On("Query", mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
303+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
304+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, "test request").Return(queryResult, nil).After(time.Second * 2)
305305
},
306306
QueryAttempts: 2,
307307
QueryTimeout: 1,
@@ -392,7 +392,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) {
392392
` + "```" + `
393393
`),
394394
}
395-
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
395+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
396396
},
397397

398398
Language: &golang.Language{},
@@ -446,7 +446,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) {
446446
` + "```" + `
447447
`),
448448
}
449-
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
449+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
450450
},
451451

452452
Language: &java.Language{},
@@ -876,7 +876,7 @@ func TestModelTranspile(t *testing.T) {
876876
queryResult := &provider.QueryResult{
877877
Message: "```\n" + transpiledFileContent + "```\n",
878878
}
879-
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
879+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
880880
},
881881

882882
Language: &golang.Language{},
@@ -929,7 +929,7 @@ func TestModelTranspile(t *testing.T) {
929929
queryResult := &provider.QueryResult{
930930
Message: "```\n" + transpiledFileContent + "```\n",
931931
}
932-
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
932+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
933933
},
934934

935935
Language: &java.Language{},
@@ -1032,7 +1032,7 @@ func TestModelMigrate(t *testing.T) {
10321032
queryResult := &provider.QueryResult{
10331033
Message: "```\n" + migratedTestFile + "```\n",
10341034
}
1035-
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
1035+
mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil)
10361036
},
10371037

10381038
Language: &java.Language{},

provider/ollama/ollama.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (p *Provider) Models() (models []model.Model, err error) {
8181
var _ provider.Query = (*Provider)(nil)
8282

8383
// Query queries the provider with the given model name.
84-
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (result *provider.QueryResult, err error) {
84+
func (p *Provider) Query(ctx context.Context, logger *log.Logger, model model.Model, promptText string) (result *provider.QueryResult, err error) {
8585
return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
8686
}
8787

provider/openai-api/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (p *Provider) SetToken(token string) {
5959
var _ provider.Query = (*Provider)(nil)
6060

6161
// Query queries the provider with the given model name.
62-
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (result *provider.QueryResult, err error) {
62+
func (p *Provider) Query(ctx context.Context, logger *log.Logger, model model.Model, promptText string) (result *provider.QueryResult, err error) {
6363
return QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
6464
}
6565

provider/openrouter/openrouter.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@ func (p *Provider) SetToken(token string) {
137137
var _ provider.Query = (*Provider)(nil)
138138

139139
// Query queries the provider with the given model name.
140-
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (result *provider.QueryResult, err error) {
140+
func (p *Provider) Query(ctx context.Context, logger *log.Logger, model model.Model, promptText string) (result *provider.QueryResult, err error) {
141141
queryResult, err := openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
142142
if err != nil {
143143
return nil, pkgerrors.WithStack(err)
144144
}
145145

146-
queryResult.GenerationInfo, err = p.fetchGenerationInfo(queryResult.ResponseID)
146+
queryResult.GenerationInfo, err = p.fetchGenerationInfo(logger, queryResult.ResponseID)
147147
if err != nil {
148148
return nil, pkgerrors.WithStack(err)
149149
}
@@ -159,7 +159,7 @@ func (p *Provider) client() (client *openai.Client) {
159159
return openai.NewClientWithConfig(config)
160160
}
161161

162-
func (p *Provider) fetchGenerationInfo(generationID string) (generationInfo *provider.GenerationInfo, err error) {
162+
func (p *Provider) fetchGenerationInfo(logger *log.Logger, generationID string) (generationInfo *provider.GenerationInfo, err error) {
163163
request, err := http.NewRequest("GET", "https://openrouter.ai/api/v1/generation?id="+generationID, nil)
164164
if err != nil {
165165
return nil, pkgerrors.WithStack(err)

provider/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ type GenerationInfo struct {
7373
// Query is a provider that allows to query a model directly.
7474
type Query interface {
7575
// Query queries the provider with the given model name.
76-
Query(ctx context.Context, model model.Model, promptText string) (result *QueryResult, err error)
76+
Query(ctx context.Context, logger *log.Logger, model model.Model, promptText string) (result *QueryResult, err error)
7777
}
7878

7979
// Service is a provider that requires background services.

provider/testing/Query_mock_gen.go

Lines changed: 11 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)