Skip to content

Commit d894997

Browse files
155 add source and test file diffs in database logging (#156)
* Adding comments to top level CoverAgent class. * Refactored out init runs from UnitTestGenerator. * Added source file to DB. * Added language. * Refactored out DRY RUN and updated HTML template for DB insertion. * Updated report template. * Added report generation with poetry. * Refactored unit test calls. * Adding test analysis and summary report. * Updated doc and fixed test analysis logic.
1 parent 5df4c89 commit d894997

File tree

13 files changed

+470
-333
lines changed

13 files changed

+470
-333
lines changed

cover_agent/CoverAgent.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212
class CoverAgent:
1313
def __init__(self, args):
14+
"""
15+
Initialize the CoverAgent class with the provided arguments and run the test generation process.
16+
17+
Parameters:
18+
args (Namespace): The parsed command-line arguments containing necessary information for test generation.
19+
20+
Returns:
21+
None
22+
"""
1423
self.args = args
1524
self.logger = CustomLogger.get_logger(__name__)
1625

@@ -33,49 +42,99 @@ def __init__(self, args):
3342
)
3443

3544
def _validate_paths(self):
45+
"""
46+
Validate the paths provided in the arguments.
47+
48+
Raises:
49+
FileNotFoundError: If the source file or test file is not found at the specified paths.
50+
"""
51+
# Ensure the source file exists
3652
if not os.path.isfile(self.args.source_file_path):
3753
raise FileNotFoundError(
3854
f"Source file not found at {self.args.source_file_path}"
3955
)
56+
# Ensure the test file exists
4057
if not os.path.isfile(self.args.test_file_path):
4158
raise FileNotFoundError(
4259
f"Test file not found at {self.args.test_file_path}"
4360
)
61+
# Create default DB file if not provided
4462
if not self.args.log_db_path:
45-
# Create default DB file if not provided
4663
self.args.log_db_path = "cover_agent_unit_test_runs.db"
64+
# Connect to the test DB
4765
self.test_db = UnitTestDB(db_connection_string=f"sqlite:///{self.args.log_db_path}")
4866

4967
def _duplicate_test_file(self):
68+
"""
69+
Initialize the CoverAgent class with the provided arguments and run the test generation process.
70+
71+
Parameters:
72+
args (Namespace): The parsed command-line arguments containing necessary information for test generation.
73+
74+
Returns:
75+
None
76+
"""
77+
# If the test file output path is set, copy the test file there
5078
if self.args.test_file_output_path != "":
5179
shutil.copy(self.args.test_file_path, self.args.test_file_output_path)
5280
else:
81+
# Otherwise, set the test file output path to the current test file
5382
self.args.test_file_output_path = self.args.test_file_path
5483

5584
def run(self):
85+
"""
86+
Run the test generation process.
87+
88+
This method performs the following steps:
89+
90+
1. Initialize the Weights & Biases run if the WANDB_API_KEY environment variable is set.
91+
2. Initialize variables to track progress.
92+
3. Run the initial test suite analysis.
93+
4. Loop until desired coverage is reached or maximum iterations are met.
94+
5. Generate new tests.
95+
6. Loop through each new test and validate it.
96+
7. Insert the test result into the database.
97+
8. Increment the iteration count.
98+
9. Check if the desired coverage has been reached.
99+
10. If the desired coverage has been reached, log the final coverage.
100+
11. If the maximum iteration limit is reached, log a failure message if strict coverage is specified.
101+
12. Provide metrics on total token usage.
102+
13. Generate a report.
103+
14. Finish the Weights & Biases run if it was initialized.
104+
"""
105+
# Check if user has exported the WANDB_API_KEY environment variable
56106
if "WANDB_API_KEY" in os.environ:
107+
# Initialize the Weights & Biases run
57108
wandb.login(key=os.environ["WANDB_API_KEY"])
58109
time_and_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
59110
run_name = f"{self.args.model}_" + time_and_date
60111
wandb.init(project="cover-agent", name=run_name)
61112

113+
# Initialize variables to track progress
62114
iteration_count = 0
63115
test_results_list = []
64116

117+
# Run initial test suite analysis
118+
self.test_gen.get_coverage_and_build_prompt()
65119
self.test_gen.initial_test_suite_analysis()
66120

121+
# Loop until desired coverage is reached or maximum iterations are met
67122
while (
68123
self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100)
69124
and iteration_count < self.args.max_iterations
70125
):
126+
# Log the current coverage
71127
self.logger.info(
72128
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
73129
)
74130
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%")
75131

132+
# Generate new tests
76133
generated_tests_dict = self.test_gen.generate_tests(max_tokens=4096)
77134

135+
# Loop through each new test and validate it
78136
for generated_test in generated_tests_dict.get("new_tests", []):
137+
# Validate the test and record the result
79138
test_result = self.test_gen.validate_test(
80139
generated_test, self.args.run_tests_multiple_times
81140
)
@@ -84,11 +143,15 @@ def run(self):
84143
# Insert the test result into the database
85144
self.test_db.insert_attempt(test_result)
86145

146+
# Increment the iteration count
87147
iteration_count += 1
88148

149+
# Check if the desired coverage has been reached
89150
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
151+
# Run the coverage tool again if the desired coverage hasn't been reached
90152
self.test_gen.run_coverage()
91153

154+
# Log the final coverage
92155
if self.test_gen.current_coverage >= (self.test_gen.desired_coverage / 100):
93156
self.logger.info(
94157
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations."
@@ -102,15 +165,18 @@ def run(self):
102165
else:
103166
self.logger.info(failure_message)
104167

105-
# Provide metric on total token usage
168+
# Provide metrics on total token usage
106169
self.logger.info(
107170
f"Total number of input tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_input_token_count}"
108171
)
109172
self.logger.info(
110173
f"Total number of output tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_output_token_count}"
111174
)
112175

113-
ReportGenerator.generate_report(test_results_list, self.args.report_filepath)
176+
# Generate a report
177+
# ReportGenerator.generate_report(test_results_list, self.args.report_filepath)
178+
self.test_db.dump_to_report(self.args.report_filepath)
114179

180+
# Finish the Weights & Biases run if it was initialized
115181
if "WANDB_API_KEY" in os.environ:
116182
wandb.finish()

cover_agent/ReportGenerator.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import difflib
12
from jinja2 import Template
23

34

45
class ReportGenerator:
5-
# Enhanced HTML template with additional styling
6+
# HTML template with fixed code formatting and dark background for the code block
67
HTML_TEMPLATE = """
78
<!DOCTYPE html>
89
<html lang="en">
@@ -39,10 +40,14 @@ class ReportGenerator:
3940
color: red;
4041
}
4142
pre {
42-
background-color: #000000 !important;
43+
background-color: #282c34 !important;
4344
color: #ffffff !important;
44-
padding: 5px;
45+
padding: 10px;
4546
border-radius: 5px;
47+
overflow-x: auto;
48+
white-space: pre-wrap;
49+
font-family: 'Courier New', Courier, monospace;
50+
font-size: 1.1em; /* Slightly larger font size */
4651
}
4752
</style>
4853
</head>
@@ -52,18 +57,31 @@ class ReportGenerator:
5257
<th>Status</th>
5358
<th>Reason</th>
5459
<th>Exit Code</th>
55-
<th>Stderr</th>
56-
<th>Stdout</th>
57-
<th>Test</th>
60+
<th>Language</th>
61+
<th>Modified Test File</th>
62+
<th>Details</th>
5863
</tr>
5964
{% for result in results %}
6065
<tr>
6166
<td class="status-{{ result.status }}">{{ result.status }}</td>
6267
<td>{{ result.reason }}</td>
6368
<td>{{ result.exit_code }}</td>
64-
<td>{% if result.stderr %}<pre><code class="language-shell">{{ result.stderr }}</code></pre>{% else %}&nbsp;{% endif %}</td>
65-
<td>{% if result.stdout %}<pre><code class="language-shell">{{ result.stdout }}</code></pre>{% else %}&nbsp;{% endif %}</td>
66-
<td>{% if result.test %}<pre><code class="language-python">{{ result.test }}</code></pre>{% else %}&nbsp;{% endif %}</td>
69+
<td>{{ result.language }}</td>
70+
<td>
71+
<details>
72+
<summary>View Full Code</summary>
73+
<pre><code>{{ result.full_diff | safe }}</code></pre>
74+
</details>
75+
</td>
76+
<td>
77+
<details>
78+
<summary>View More</summary>
79+
<div><strong>STDERR:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.stderr }}</code></pre></div>
80+
<div><strong>STDOUT:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.stdout }}</code></pre></div>
81+
<div><strong>Test Code:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.test_code }}</code></pre></div>
82+
<div><strong>Imports:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.imports }}</code></pre></div>
83+
</details>
84+
</td>
6785
</tr>
6886
{% endfor %}
6987
</table>
@@ -72,6 +90,61 @@ class ReportGenerator:
7290
</html>
7391
"""
7492

93+
@classmethod
94+
def generate_full_diff(cls, original, processed):
95+
"""
96+
Generates a full view of both the original and processed test files,
97+
highlighting added, removed, and unchanged lines, showing the full code.
98+
99+
:param original: String content of the original test file.
100+
:param processed: String content of the processed test file.
101+
:return: Full diff string formatted for HTML display, highlighting added, removed, and unchanged lines.
102+
"""
103+
diff = difflib.ndiff(original.splitlines(), processed.splitlines())
104+
105+
diff_html = []
106+
for line in diff:
107+
if line.startswith('+'):
108+
diff_html.append(f'<span class="diff-added">{line}</span>')
109+
elif line.startswith('-'):
110+
diff_html.append(f'<span class="diff-removed">{line}</span>')
111+
else:
112+
diff_html.append(f'<span class="diff-unchanged">{line}</span>')
113+
return '\n'.join(diff_html)
114+
115+
@classmethod
116+
def generate_partial_diff(cls, original, processed, context_lines=3):
117+
"""
118+
Generates a partial diff of both the original and processed test files,
119+
showing only added, removed, or changed lines with a few lines of context.
120+
121+
:param original: String content of the original test file.
122+
:param processed: String content of the processed test file.
123+
:param context_lines: Number of context lines to include around changes.
124+
:return: Partial diff string formatted for HTML display, highlighting added, removed, and unchanged lines with context.
125+
"""
126+
# Use unified_diff to generate a partial diff with context
127+
diff = difflib.unified_diff(
128+
original.splitlines(),
129+
processed.splitlines(),
130+
n=context_lines
131+
)
132+
133+
diff_html = []
134+
for line in diff:
135+
if line.startswith('+') and not line.startswith('+++'):
136+
diff_html.append(f'<span class="diff-added">{line}</span>')
137+
elif line.startswith('-') and not line.startswith('---'):
138+
diff_html.append(f'<span class="diff-removed">{line}</span>')
139+
elif line.startswith('@@'):
140+
# Highlight the diff context (line numbers)
141+
diff_html.append(f'<span class="diff-context">{line}</span>')
142+
else:
143+
# Show unchanged lines as context
144+
diff_html.append(f'<span class="diff-unchanged">{line}</span>')
145+
146+
return '\n'.join(diff_html)
147+
75148
@classmethod
76149
def generate_report(cls, results, file_path):
77150
"""
@@ -80,6 +153,10 @@ def generate_report(cls, results, file_path):
80153
:param results: List of dictionaries with test results.
81154
:param file_path: Path to the HTML file where the report will be written.
82155
"""
156+
# Generate the full diff for each result
157+
for result in results:
158+
result['full_diff'] = cls.generate_full_diff(result['original_test_file'], result['processed_test_file'])
159+
83160
template = Template(cls.HTML_TEMPLATE)
84161
html_content = template.render(results=results)
85162

0 commit comments

Comments
 (0)