|
1 | 1 | import typing as T |
2 | | -from .prompter import Prompter, CodeSegment, mk_message, get_module_name |
3 | | -from ..utils import lines_branches_do |
| 2 | +from .prompter import * |
| 3 | +import coverup.codeinfo as codeinfo |
4 | 4 |
|
5 | 5 |
|
6 | 6 | class ClaudePrompter(Prompter): |
7 | | - """Prompter for Claude.""" |
| 7 | + """Prompter tuned for Claude 3 Sonnet.""" |
8 | 8 |
|
9 | 9 | def __init__(self, *args, **kwargs): |
10 | 10 | super().__init__(*args, **kwargs) |
11 | 11 |
|
| 12 | + def system_prompt(self) -> T.List[dict]: |
| 13 | + """Optional: prepend this message when using Sonnet for better reliability.""" |
| 14 | + return [mk_message( |
| 15 | + "You are a code generator that writes complete Python test files. " |
| 16 | + "You must respond only with valid Python code enclosed in triple backticks, and nothing else. " |
| 17 | + "Do not include explanations or commentary." |
| 18 | + )] |
12 | 19 |
|
13 | 20 | def initial_prompt(self, segment: CodeSegment) -> T.List[dict]: |
14 | | - module_name = get_module_name(segment.path, self.args.src_base_dir) |
15 | 21 | filename = segment.path.relative_to(self.args.src_base_dir) |
16 | 22 |
|
17 | 23 | return [ |
18 | | - mk_message("You are an expert Python test-driven developer who creates pytest test functions that achieve high coverage.", |
19 | | - role="system"), |
| 24 | + *self.system_prompt(), |
20 | 25 | mk_message(f""" |
21 | | -<file path="{filename}" module_name="{module_name}"> |
22 | | -{segment.get_excerpt(tag_lines=bool(segment.executed_lines))} |
23 | | -</file> |
| 26 | +You are an expert Python test-driven developer. |
24 | 27 |
|
25 | | -<instructions> |
| 28 | +The following code, extracted from {filename}, does not achieve full coverage: |
| 29 | +{segment.lines_branches_missing_do()} not execute. |
26 | 30 |
|
27 | | -The code above does not achieve full coverage: |
28 | | -when tested, {'it does' if not segment.executed_lines else segment.lines_branches_missing_do()} not execute. |
| 31 | +Your task: |
| 32 | +- Write new **pytest test functions** that cause all missing lines and branches to execute. |
| 33 | +- Tests must be correct and include assertions that verify postconditions. |
| 34 | +- If necessary, use the `get_info` tool function to learn more about symbols. |
| 35 | +- Ensure each test leaves no state behind; use `monkeypatch` or `pytest-mock` if helpful. |
| 36 | +- Do NOT include any top-level code that calls `pytest.main` or the test itself. |
29 | 37 |
|
30 | | -1. Create a new pytest test function that executes these missing lines/branches, always making |
31 | | -sure that the new test is correct and indeed improves coverage. |
| 38 | +Respond with **only the full Python test file**, enclosed in triple backticks. |
32 | 39 |
|
33 | | -2. Always send entire Python test scripts when proposing a new test or correcting one you |
34 | | -previously proposed. |
| 40 | +Here is the code to test: |
35 | 41 |
|
36 | | -3. Be sure to include assertions in the test that verify any applicable postconditions. |
| 42 | +```python |
| 43 | +{segment.get_excerpt()} |
| 44 | +``` |
| 45 | +""") |
| 46 | + ] |
37 | 47 |
|
38 | | -4. Please also make VERY SURE to clean up after the test, so as not to affect other tests; |
39 | | -use 'pytest-mock' if appropriate. |
40 | 48 |
|
41 | | -5. Write as little top-level code as possible, and in particular do not include any top-level code |
42 | | -calling into pytest.main or the test itself. |
| 49 | + def error_prompt(self, segment: CodeSegment, error: str) -> T.List[dict] | None: |
| 50 | + return [ |
| 51 | + *self.system_prompt(), |
| 52 | + mk_message(f"""\ |
| 53 | +The test produced an error: |
43 | 54 |
|
44 | | -6. Respond with the Python code enclosed in backticks. Before answering the question, please think about it step-by-step within <thinking></thinking> tags. Then, provide your final answer within <answer></answer> tags. |
45 | | -</instructions> |
46 | | -""") |
47 | | - ] |
| 55 | +{error} |
48 | 56 |
|
| 57 | +Please revise the test to correct the error. |
49 | 58 |
|
50 | | - def error_prompt(self, segment: CodeSegment, error: str) -> T.List[dict]: |
51 | | - return [mk_message(f"""\ |
52 | | -<error>{error}</error> |
53 | | -Executing the test yields an error, shown above. |
54 | | -<instructions> |
55 | | -1. Modify the test to correct it. |
56 | | -2. Respond with the complete Python code in backticks. |
57 | | -3. Before answering the question, please think about it step-by-step within <thinking></thinking> tags. Then, provide your final answer within <answer></answer> tags. |
58 | | -</instructions> |
| 59 | +Respond with only the complete revised Python test file, enclosed in triple backticks. |
| 60 | +You may use the `get_info` tool function if needed. |
59 | 61 | """) |
60 | 62 | ] |
61 | 63 |
|
62 | 64 |
|
63 | 65 | def missing_coverage_prompt(self, segment: CodeSegment, |
64 | | - missing_lines: set, missing_branches: set) -> T.List[dict]: |
65 | | - return [mk_message(f"""\ |
66 | | -This test still lacks coverage: {lines_branches_do(missing_lines, set(), missing_branches)} not execute. |
67 | | -<instructions> |
68 | | -1. Modify it to execute those lines. |
69 | | -2. Respond with the complete Python code in backticks. |
70 | | -3. Before responding, please think about it step-by-step within <thinking></thinking> tags. Then, provide your final answer within <answer></answer> tags. |
71 | | -</instructions> |
| 66 | + missing_lines: set, missing_branches: set) -> T.List[dict] | None: |
| 67 | + return [ |
| 68 | + *self.system_prompt(), |
| 69 | + mk_message(f"""\ |
| 70 | +The test still lacks coverage: {lines_branches_do(missing_lines, set(), missing_branches)} not execute. |
| 71 | +
|
| 72 | +Revise the test to ensure full coverage. |
| 73 | +
|
| 74 | +Respond with only the complete revised Python test file, enclosed in triple backticks. |
| 75 | +You may use the `get_info` tool function if helpful. |
72 | 76 | """) |
73 | 77 | ] |
| 78 | + |
| 79 | + |
| 80 | + def get_info(self, ctx: CodeSegment, name: str) -> str: |
| 81 | + """ |
| 82 | + { |
| 83 | + "name": "get_info", |
| 84 | + "description": "Returns information about a symbol.", |
| 85 | + "parameters": { |
| 86 | + "type": "object", |
| 87 | + "properties": { |
| 88 | + "name": { |
| 89 | + "type": "string", |
| 90 | + "description": "class, function or method name, as in 'f' for function f or 'C.foo' for method foo in class C." |
| 91 | + } |
| 92 | + }, |
| 93 | + "required": ["name"] |
| 94 | + } |
| 95 | + } |
| 96 | + """ |
| 97 | + |
| 98 | + if info := codeinfo.get_info(codeinfo.parse_file(ctx.path), name, line=ctx.begin): |
| 99 | + return "\"...\" below indicates omitted code.\n\n" + info |
| 100 | + |
| 101 | + return f"Unable to obtain information on {name}." |
| 102 | + |
| 103 | + |
| 104 | + def get_functions(self) -> T.List[T.Callable]: |
| 105 | + return [self.get_info] |
0 commit comments