Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cover_agent/coverage_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ def parse_missed_covered_lines_jacoco_csv(self, package_name: str, class_name: s

def extract_package_and_class_java(self):
package_pattern = re.compile(r"^\s*package\s+([\w\.]+)\s*;.*$")
class_pattern = re.compile(r"^\s*(?:public\s+)?(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)))?(?:\s+extends|\s+implements|\s*\{|$)")

class_pattern = re.compile(r"^\s*(?:(?:public|private|protected|static|final|abstract)\s+)*(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)|$))?(?:\s+extends|\s+implements|\s*\{|$)")

package_name = ""
class_name = ""
Expand Down
3 changes: 3 additions & 0 deletions cover_agent/unit_test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class UnitTestGenerationAttempt(Base):
source_file = Column(Text)
original_test_file = Column(Text)
processed_test_file = Column(Text)
error_summary = Column(Text)


class UnitTestDB:
Expand All @@ -52,6 +53,7 @@ def insert_attempt(self, test_result: dict):
source_file=test_result.get("source_file"),
original_test_file=test_result.get("original_test_file"),
processed_test_file=test_result.get("processed_test_file"),
error_summary=test_result.get("error_summary"),
)
session.add(new_attempt)
session.commit()
Expand Down Expand Up @@ -83,6 +85,7 @@ def get_all_attempts(self):
"source_file": attempt.source_file,
"original_test_file": attempt.original_test_file,
"processed_test_file": attempt.processed_test_file,
"error_summary": attempt.error_summary,
}
for attempt in attempts
]
Expand Down
46 changes: 31 additions & 15 deletions cover_agent/unit_test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,16 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_imports_after = None
counter_attempts = 0
while not relevant_line_number_to_insert_tests_after and counter_attempts < allowed_attempts:
test_file_content_numbered = "\n".join(
f"{i + 1} {line}"
for i, line in enumerate(
self._read_file(self.test_file_path).split("\n")
)
)
response, prompt_token_count, response_token_count, prompt = (
self.agent_completion.analyze_test_insert_line(
language=self.language,
test_file_numbered="\n".join(
f"{i + 1} {line}" for i, line in enumerate(self._read_file(self.test_file_path).split("\n"))
),
test_file_numbered=test_file_content_numbered,
additional_instructions_text=self.additional_instructions,
test_file_name=os.path.relpath(self.test_file_path, self.project_root),
)
Expand All @@ -242,6 +246,12 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_tests_after = tests_dict.get(
"relevant_line_number_to_insert_tests_after", None
)

if relevant_line_number_to_insert_tests_after:
file_len = len(test_file_content_numbered.splitlines())
if relevant_line_number_to_insert_tests_after == file_len:
relevant_line_number_to_insert_tests_after -= 1

relevant_line_number_to_insert_imports_after = tests_dict.get(
"relevant_line_number_to_insert_imports_after", None
)
Expand All @@ -257,6 +267,7 @@ def initial_test_suite_analysis(self):
f"Failed to analyze the relevant line number to insert new imports. tests_dict: {tests_dict}"
)


self.test_headers_indentation = test_headers_indentation
self.relevant_line_number_to_insert_tests_after = relevant_line_number_to_insert_tests_after
self.relevant_line_number_to_insert_imports_after = relevant_line_number_to_insert_imports_after
Expand Down Expand Up @@ -448,6 +459,15 @@ def validate_test(self, generated_test: dict):
with open(self.test_file_path, "w") as test_file:
test_file.write(original_content)
self.logger.info(f"Skipping a generated test that failed")

error_message = self.extract_error_message(processed_test, stdout, stderr)
if error_message:
logging.error(f"Error message summary:\n{error_message}")

self.failed_test_runs.append(
{"code": generated_test, "error_message": error_message}
) # Append failure details to the list

fail_details = {
"status": "FAIL",
"reason": "Test failed",
Expand All @@ -459,15 +479,10 @@ def validate_test(self, generated_test: dict):
"source_file": self.source_code,
"original_test_file": original_content,
"processed_test_file": processed_test,
"error_summary": error_message,
}

error_message = self.extract_error_message(fail_details)
if error_message:
logging.error(f"Error message summary:\n{error_message}")

self.failed_test_runs.append(
{"code": generated_test, "error_message": error_message}
) # Append failure details to the list

if "WANDB_API_KEY" in os.environ:
fail_details["error_message"] = error_message
Expand Down Expand Up @@ -618,7 +633,7 @@ def to_dict(self):
def to_json(self):
return json.dumps(self.to_dict())

def extract_error_message(self, fail_details):
def extract_error_message(self, processed_test_file, stdout, stderr):
"""
Extracts the error message from the provided fail details.

Expand All @@ -627,8 +642,9 @@ def extract_error_message(self, fail_details):
Logs errors encountered during the process.

Parameters:
fail_details (dict): Dictionary containing test failure details including stderr, stdout,
and processed test file contents.
processed_test_file: The test file under processing
stdout: Output message
stderr: Error message

Returns:
str: The error summary extracted from the response or an empty string if extraction fails.
Expand All @@ -638,9 +654,9 @@ def extract_error_message(self, fail_details):
response, prompt_token_count, response_token_count, prompt = self.agent_completion.analyze_test_failure(
source_file_name=os.path.relpath(self.source_file_path, self.project_root),
source_file=self._read_file(self.source_file_path),
processed_test_file=fail_details["processed_test_file"],
stderr=fail_details["stderr"],
stdout=fail_details["stdout"],
processed_test_file=processed_test_file,
stderr=stderr,
stdout=stdout,
test_file_name=os.path.relpath(self.test_file_path, self.project_root),
)
self.total_input_token_count += prompt_token_count
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.10
0.3.12
3 changes: 3 additions & 0 deletions tests/test_unit_test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_insert_attempt(self, unit_test_db):
"source_file": "sample source code",
"original_test_file": "sample test code",
"processed_test_file": "sample new test code",
"error_summary": "sample error summary",
}

# Insert the test result into the database
Expand All @@ -74,6 +75,7 @@ def test_insert_attempt(self, unit_test_db):
assert attempt.source_file == "sample source code"
assert attempt.original_test_file == "sample test code"
assert attempt.processed_test_file == "sample new test code"
assert attempt.error_summary == "sample error summary"

def test_dump_to_report(self, unit_test_db, tmp_path):
"""
Expand All @@ -94,6 +96,7 @@ def test_dump_to_report(self, unit_test_db, tmp_path):
"source_file": "sample source code",
"original_test_file": "sample test code",
"processed_test_file": "sample new test code",
"error_summary": "sample error summary",
}

# Insert the test result into the database
Expand Down
8 changes: 6 additions & 2 deletions tests/test_unit_test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def test_extract_error_message_exception_handling(self):
"stdout": "stdout content",
"processed_test_file": "",
}
error_message = generator.extract_error_message(fail_details)
error_message = generator.extract_error_message(fail_details["processed_test_file"],
fail_details["stdout"],
fail_details["stderr"])

# Should return an empty string on failure
assert error_message == ""
Expand Down Expand Up @@ -169,7 +171,9 @@ def test_extract_error_message_with_prompt_builder(self):
"source_file_name": temp_source_file.name,
"source_file": "",
}
error_message = generator.extract_error_message(fail_details)
error_message = generator.extract_error_message(fail_details["processed_test_file"],
fail_details["stdout"],
fail_details["stderr"])

assert error_message.strip() == "error_summary: Test failed due to assertion error in test_example"
mock_agent_completion_call_args = mock_agent_completion.analyze_test_failure.call_args[1]
Expand Down
Loading