Skip to content

Commit 81bc339

Browse files
committed
Support "write-test" task for Rust
Requires special handling for languages that have their tests within implementation files Part of #448
1 parent 26235c0 commit 81bc339

File tree

6 files changed

+175
-11
lines changed

6 files changed

+175
-11
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
_ "github.com/symflower/eval-dev-quality/language/golang" // Register language.
2828
_ "github.com/symflower/eval-dev-quality/language/java" // Register language.
2929
_ "github.com/symflower/eval-dev-quality/language/ruby" // Register language.
30+
_ "github.com/symflower/eval-dev-quality/language/rust" // Register language.
3031
"github.com/symflower/eval-dev-quality/log"
3132
"github.com/symflower/eval-dev-quality/model"
3233
"github.com/symflower/eval-dev-quality/model/llm"

evaluate/task/test-integration/task_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@ package testintegration
33
import (
44
"path/filepath"
55
"testing"
6+
"time"
67

78
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/mock"
810
"github.com/stretchr/testify/require"
911
"github.com/symflower/eval-dev-quality/evaluate/metrics"
1012
evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task"
1113
tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing"
1214
"github.com/symflower/eval-dev-quality/language/golang"
15+
"github.com/symflower/eval-dev-quality/language/rust"
1316
"github.com/symflower/eval-dev-quality/log"
17+
"github.com/symflower/eval-dev-quality/model/llm"
1418
"github.com/symflower/eval-dev-quality/model/symflower"
19+
"github.com/symflower/eval-dev-quality/provider"
20+
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
1521
evaltask "github.com/symflower/eval-dev-quality/task"
1622
"github.com/symflower/eval-dev-quality/tools"
1723
toolstesting "github.com/symflower/eval-dev-quality/tools/testing"
24+
"github.com/zimmski/osutil/bytesutil"
1825
)
1926

2027
func TestWriteTestsRun(t *testing.T) {
@@ -93,4 +100,82 @@ func TestWriteTestsRun(t *testing.T) {
93100
assert.Contains(t, data, "msg=\"evaluated model\" model=symflower/symbolic-execution")
94101
},
95102
})
103+
{
104+
mockProvider := providertesting.NewMockQuery(t)
105+
mockProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(
106+
&provider.QueryResult{
107+
Message: bytesutil.StringTrimIndentations(`
108+
` + "```rust`" + `
109+
#[cfg(test)]
110+
mod tests {
111+
use super::*;
112+
113+
#[test]
114+
fn test_plain() {
115+
plain();
116+
}
117+
}
118+
` + "```" + `
119+
`),
120+
},
121+
nil,
122+
).After(100 * time.Millisecond)
123+
model := llm.NewModel(mockProvider, "model")
124+
validate(t, &tasktesting.TestCaseTask{
125+
Name: "Rust",
126+
127+
Model: model,
128+
Language: &rust.Language{},
129+
TestDataPath: filepath.Join("..", "..", "..", "testdata"),
130+
RepositoryPath: filepath.Join("rust", "plain"),
131+
132+
ExpectedRepositoryAssessment: map[string]map[evaltask.Identifier]metrics.Assessments{
133+
filepath.Join("src", "lib.rs"): {
134+
evaluatetask.IdentifierWriteTests: metrics.Assessments{
135+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
136+
metrics.AssessmentKeyResponseCharacterCount: 98,
137+
metrics.AssessmentKeyCoverage: 0, // TODO Get coverage.
138+
metrics.AssessmentKeyFilesExecuted: 1,
139+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
140+
metrics.AssessmentKeyResponseNoError: 1,
141+
metrics.AssessmentKeyResponseNoExcess: 1,
142+
metrics.AssessmentKeyResponseWithCode: 1,
143+
},
144+
evaluatetask.IdentifierWriteTestsSymflowerFix: metrics.Assessments{
145+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
146+
metrics.AssessmentKeyResponseCharacterCount: 98,
147+
metrics.AssessmentKeyCoverage: 0, // TODO Get coverage.
148+
metrics.AssessmentKeyFilesExecuted: 1,
149+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
150+
metrics.AssessmentKeyResponseNoError: 1,
151+
metrics.AssessmentKeyResponseNoExcess: 1,
152+
metrics.AssessmentKeyResponseWithCode: 1,
153+
},
154+
evaluatetask.IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
155+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
156+
metrics.AssessmentKeyResponseCharacterCount: 98,
157+
metrics.AssessmentKeyCoverage: 0, // TODO Get coverage.
158+
metrics.AssessmentKeyFilesExecuted: 1,
159+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
160+
metrics.AssessmentKeyResponseNoError: 1,
161+
metrics.AssessmentKeyResponseNoExcess: 1,
162+
metrics.AssessmentKeyResponseWithCode: 1,
163+
},
164+
evaluatetask.IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
165+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
166+
metrics.AssessmentKeyResponseCharacterCount: 98,
167+
metrics.AssessmentKeyCoverage: 0, // TODO Get coverage.
168+
metrics.AssessmentKeyFilesExecuted: 1,
169+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
170+
metrics.AssessmentKeyResponseNoError: 1,
171+
metrics.AssessmentKeyResponseNoExcess: 1,
172+
metrics.AssessmentKeyResponseWithCode: 1,
173+
},
174+
},
175+
},
176+
ValidateLog: func(t *testing.T, data string) {
177+
assert.Contains(t, data, "test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out;")
178+
},
179+
})
180+
}
96181
}

evaluate/task/write-test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ func validateWriteTestsRepository(logger *log.Logger, repositoryPath string, lan
175175
var sourceFiles []string
176176
var testFiles []string
177177
for _, file := range files {
178-
if strings.HasSuffix(file, language.DefaultTestFileSuffix()) {
179-
testFiles = append(testFiles, file)
180-
} else if strings.HasSuffix(file, language.DefaultFileExtension()) {
178+
if strings.HasSuffix(file, language.DefaultFileExtension()) { // For languages where source file == test file, assume we are collecting source files by default.
181179
sourceFiles = append(sourceFiles, file)
180+
} else if strings.HasSuffix(file, language.DefaultTestFileSuffix()) {
181+
testFiles = append(testFiles, file)
182182
}
183183
}
184184

language/rust/language.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func (l *Language) ExecuteTests(logger *log.Logger, repositoryPath string) (test
7676
commandOutput, err := util.CommandWithResult(context.Background(), logger, &util.Command{
7777
Command: []string{ // TODO Move this to `symflower test` to get coverage information.
7878
"cargo",
79-
"test",
79+
"cargo-llvm-cov",
8080
},
8181

8282
Directory: repositoryPath,

model/llm/llm.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llm
22

33
import (
44
"context"
5+
"errors"
56
"os"
67
"path/filepath"
78
"strings"
@@ -145,7 +146,7 @@ type llmWriteTestSourceFilePromptContext struct {
145146

146147
// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
147148
var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
148-
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with .TestFramework }} with {{ . }} as a test framework{{ end }}.
149+
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" {{- with .ImportPath }} with package "{{ . }}" {{- end }}, provide {{- if .Language.HasTestsInSource }} tests {{ else }} a test file {{ end -}} for this code{{ with .TestFramework }} with {{ . }} as a test framework{{ end }}.
149150
The tests should produce 100 percent code coverage and must compile.
150151
The response must contain only the test code in a fenced code block and nothing else.
151152
@@ -328,7 +329,7 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
328329

329330
filePath := filepath.Join(ctx.RepositoryPath, ctx.Language.TestFilePath(ctx.RepositoryPath, ctx.FilePath))
330331

331-
return handleQueryResult(queryResult, filePath)
332+
return handleQueryResult(queryResult, filePath, ctx.Language.HasTestsInSource())
332333
}
333334

334335
func (m *Model) query(logger *log.Logger, request string) (queryResult *provider.QueryResult, err error) {
@@ -413,7 +414,7 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e
413414
return nil, pkgerrors.WithStack(err)
414415
}
415416

416-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
417+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
417418
}
418419

419420
var _ model.CapabilityTranspile = (*Model)(nil)
@@ -460,7 +461,7 @@ func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, er
460461
return nil, pkgerrors.WithStack(err)
461462
}
462463

463-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
464+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
464465
}
465466

466467
var _ model.CapabilityMigrate = (*Model)(nil)
@@ -500,10 +501,10 @@ func (m *Model) Migrate(ctx model.Context) (assessment metrics.Assessments, err
500501
return nil, pkgerrors.WithStack(err)
501502
}
502503

503-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
504+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
504505
}
505506

506-
func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute string) (assessment metrics.Assessments, err error) {
507+
func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute string, appendFile bool) (assessment metrics.Assessments, err error) {
507508
assessment, sourceFileContent, err := prompt.ParseResponse(queryResult.Message)
508509
if err != nil {
509510
return nil, pkgerrors.WithStack(err)
@@ -526,7 +527,21 @@ func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute strin
526527
if err := os.MkdirAll(filepath.Dir(filePathAbsolute), 0755); err != nil {
527528
return nil, pkgerrors.WithStack(err)
528529
}
529-
if err := os.WriteFile(filePathAbsolute, []byte(sourceFileContent), 0644); err != nil {
530+
531+
flags := os.O_WRONLY | os.O_CREATE
532+
if appendFile {
533+
flags = flags | os.O_APPEND
534+
}
535+
file, err := os.OpenFile(filePathAbsolute, flags, 0644)
536+
if err != nil {
537+
return nil, pkgerrors.WithStack(err)
538+
}
539+
defer func() {
540+
if closeErr := file.Close(); closeErr != nil {
541+
err = errors.Join(err, pkgerrors.WithStack(closeErr))
542+
}
543+
}()
544+
if _, err := file.WriteString(sourceFileContent); err != nil {
530545
return nil, pkgerrors.WithStack(err)
531546
}
532547

model/llm/llm_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/symflower/eval-dev-quality/language"
2222
"github.com/symflower/eval-dev-quality/language/golang"
2323
"github.com/symflower/eval-dev-quality/language/java"
24+
"github.com/symflower/eval-dev-quality/language/rust"
2425
"github.com/symflower/eval-dev-quality/log"
2526
"github.com/symflower/eval-dev-quality/model"
2627
"github.com/symflower/eval-dev-quality/provider"
@@ -654,6 +655,68 @@ func TestFormatPromptContext(t *testing.T) {
654655
` + "```" + `
655656
`),
656657
})
658+
659+
validate(t, &testCase{
660+
Name: "No Import path",
661+
662+
Context: &llmWriteTestSourceFilePromptContext{
663+
llmSourceFilePromptContext: llmSourceFilePromptContext{
664+
Language: &golang.Language{},
665+
666+
Code: bytesutil.StringTrimIndentations(`
667+
package increment
668+
669+
func increment(i int) int
670+
return i + 1
671+
}
672+
`),
673+
FilePath: filepath.Join("path", "to", "increment.go"),
674+
ImportPath: "",
675+
},
676+
},
677+
678+
ExpectedMessage: bytesutil.StringTrimIndentations(`
679+
Given the following Go code file "path/to/increment.go", provide a test file for this code.
680+
The tests should produce 100 percent code coverage and must compile.
681+
The response must contain only the test code in a fenced code block and nothing else.
682+
683+
` + "```" + `golang
684+
package increment
685+
686+
func increment(i int) int
687+
return i + 1
688+
}
689+
` + "```" + `
690+
`),
691+
})
692+
693+
validate(t, &testCase{
694+
Name: "Tests in source file",
695+
696+
Context: &llmWriteTestSourceFilePromptContext{
697+
llmSourceFilePromptContext: llmSourceFilePromptContext{
698+
Language: &rust.Language{},
699+
700+
Code: bytesutil.StringTrimIndentations(`
701+
fn main() {
702+
}
703+
`),
704+
FilePath: filepath.Join("path", "to", "main.rs"),
705+
ImportPath: "",
706+
},
707+
},
708+
709+
ExpectedMessage: bytesutil.StringTrimIndentations(`
710+
Given the following Rust code file "path/to/main.rs", provide tests for this code.
711+
The tests should produce 100 percent code coverage and must compile.
712+
The response must contain only the test code in a fenced code block and nothing else.
713+
714+
` + "```" + `rust
715+
fn main() {
716+
}
717+
` + "```" + `
718+
`),
719+
})
657720
})
658721

659722
validate(t, &testCase{

0 commit comments

Comments
 (0)