Skip to content

Commit b9ce620

Browse files
committed
refactor, Do mocking setup only if test is really executed so mockito does not complain about "missing calls" when selectively targeting other tests
1 parent a9980c1 commit b9ce620

File tree

3 files changed

+45
-30
lines changed

3 files changed

+45
-30
lines changed

evaluate/task/test-integration/task_test.go

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -102,28 +102,31 @@ func TestWriteTestsRun(t *testing.T) {
102102
})
103103
{
104104
mockProvider := providertesting.NewMockQuery(t)
105-
mockProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(
106-
&provider.QueryResult{
107-
Message: bytesutil.StringTrimIndentations(`
108-
` + "```rust`" + `
109-
#[cfg(test)]
110-
mod tests {
111-
use super::*;
112-
113-
#[test]
114-
fn test_plain() {
115-
plain();
116-
}
117-
}
118-
` + "```" + `
119-
`),
120-
},
121-
nil,
122-
).After(100 * time.Millisecond)
123105
model := llm.NewModel(mockProvider, "model")
124106
validate(t, &tasktesting.TestCaseTask{
125107
Name: "Rust",
126108

109+
Setup: func(t *testing.T) {
110+
mockProvider.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
111+
&provider.QueryResult{
112+
Message: bytesutil.StringTrimIndentations(`
113+
` + "```rust`" + `
114+
#[cfg(test)]
115+
mod tests {
116+
use super::*;
117+
118+
#[test]
119+
fn test_plain() {
120+
plain();
121+
}
122+
}
123+
` + "```" + `
124+
`),
125+
},
126+
nil,
127+
).After(100 * time.Millisecond)
128+
},
129+
127130
Model: model,
128131
Language: &rust.Language{},
129132
TestDataPath: filepath.Join("..", "..", "..", "testdata"),

evaluate/task/testing/task.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
type TestCaseTask struct {
2323
Name string
2424

25+
Setup func(t *testing.T)
26+
2527
Task evaltask.Task
2628
Model model.Model
2729
Language language.Language
@@ -51,6 +53,10 @@ func (tc *TestCaseTask) Validate(t *testing.T, createRepository createRepository
5153
assert.NoError(t, err)
5254
defer cleanup()
5355

56+
if tc.Setup != nil {
57+
tc.Setup(t)
58+
}
59+
5460
taskContext := evaltask.Context{
5561
Language: tc.Language,
5662
Repository: repository,

evaluate/task/write-test_test.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,15 @@ func TestWriteTestsRun(t *testing.T) {
130130
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath))
131131

132132
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
133-
// Simulate that a model does not generate anything.
134-
modelMock.MockCapabilityWriteTests.On("WriteTests", mock.Anything).Return(metricstesting.AssessmentsWithProcessingTime, nil)
135133

136134
validate(t, &tasktesting.TestCaseTask{
137135
Name: "Reset symflower template so it's not mistaken for model solution",
138136

137+
Setup: func(t *testing.T) {
138+
// Simulate that a model does not generate anything.
139+
modelMock.MockCapabilityWriteTests.On("WriteTests", mock.Anything).Return(metricstesting.AssessmentsWithProcessingTime, nil)
140+
},
141+
139142
Model: modelMock,
140143
Language: &golang.Language{},
141144
TestDataPath: temporaryDirectoryPath,
@@ -389,19 +392,22 @@ func TestWriteTestsRun(t *testing.T) {
389392
// There will be no template for an empty file.
390393
`)), 0666))
391394
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
392-
modelMock.RegisterGenerateSuccess(t, "empty_test.go", "package plain\n", metricstesting.AssessmentsWithProcessingTime).Once()
393-
modelMock.RegisterGenerateSuccess(t, "plain_test.go", bytesutil.StringTrimIndentations(`
394-
package plain
395-
396-
import "testing"
397-
398-
func TestPlain(t *testing.T) {
399-
plain()
400-
}
401-
`), metricstesting.AssessmentsWithProcessingTime)
402395
validate(t, &tasktesting.TestCaseTask{
403396
Name: "Keep non-template score if template fails",
404397

398+
Setup: func(t *testing.T) {
399+
modelMock.RegisterGenerateSuccess(t, "empty_test.go", "package plain\n", metricstesting.AssessmentsWithProcessingTime).Once()
400+
modelMock.RegisterGenerateSuccess(t, "plain_test.go", bytesutil.StringTrimIndentations(`
401+
package plain
402+
403+
import "testing"
404+
405+
func TestPlain(t *testing.T) {
406+
plain()
407+
}
408+
`), metricstesting.AssessmentsWithProcessingTime)
409+
},
410+
405411
Model: modelMock,
406412
Language: &golang.Language{},
407413
TestDataPath: temporaryDirectoryPath,

0 commit comments

Comments
 (0)