diff --git a/cover_agent/coverage_processor.py b/cover_agent/coverage_processor.py index beadb372..c00dae55 100644 --- a/cover_agent/coverage_processor.py +++ b/cover_agent/coverage_processor.py @@ -307,8 +307,7 @@ def parse_missed_covered_lines_jacoco_csv(self, package_name: str, class_name: s def extract_package_and_class_java(self): package_pattern = re.compile(r"^\s*package\s+([\w\.]+)\s*;.*$") - class_pattern = re.compile(r"^\s*(?:public\s+)?(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)))?(?:\s+extends|\s+implements|\s*\{|$)") - + class_pattern = re.compile(r"^\s*(?:(?:public|private|protected|static|final|abstract)\s+)*(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)|$))?(?:\s+extends|\s+implements|\s*\{|$)") package_name = "" class_name = "" diff --git a/cover_agent/unit_test_db.py b/cover_agent/unit_test_db.py index e8193a3e..c2947d21 100644 --- a/cover_agent/unit_test_db.py +++ b/cover_agent/unit_test_db.py @@ -28,6 +28,7 @@ class UnitTestGenerationAttempt(Base): source_file = Column(Text) original_test_file = Column(Text) processed_test_file = Column(Text) + error_summary = Column(Text) class UnitTestDB: @@ -52,6 +53,7 @@ def insert_attempt(self, test_result: dict): source_file=test_result.get("source_file"), original_test_file=test_result.get("original_test_file"), processed_test_file=test_result.get("processed_test_file"), + error_summary=test_result.get("error_summary"), ) session.add(new_attempt) session.commit() @@ -83,6 +85,7 @@ def get_all_attempts(self): "source_file": attempt.source_file, "original_test_file": attempt.original_test_file, "processed_test_file": attempt.processed_test_file, + "error_summary": attempt.error_summary, } for attempt in attempts ] diff --git a/cover_agent/unit_test_validator.py b/cover_agent/unit_test_validator.py index 6193e161..e71c80d5 100644 --- a/cover_agent/unit_test_validator.py +++ b/cover_agent/unit_test_validator.py @@ -225,12 +225,16 @@ def initial_test_suite_analysis(self): relevant_line_number_to_insert_imports_after = None counter_attempts = 0 while not relevant_line_number_to_insert_tests_after and counter_attempts < allowed_attempts: + test_file_content_numbered = "\n".join( + f"{i + 1} {line}" + for i, line in enumerate( + self._read_file(self.test_file_path).split("\n") + ) + ) response, prompt_token_count, response_token_count, prompt = ( self.agent_completion.analyze_test_insert_line( language=self.language, - test_file_numbered="\n".join( - f"{i + 1} {line}" for i, line in enumerate(self._read_file(self.test_file_path).split("\n")) - ), + test_file_numbered=test_file_content_numbered, additional_instructions_text=self.additional_instructions, test_file_name=os.path.relpath(self.test_file_path, self.project_root), ) @@ -242,6 +246,12 @@ def initial_test_suite_analysis(self): relevant_line_number_to_insert_tests_after = tests_dict.get( "relevant_line_number_to_insert_tests_after", None ) + + if relevant_line_number_to_insert_tests_after: + file_len = len(test_file_content_numbered.splitlines()) + if relevant_line_number_to_insert_tests_after == file_len: + relevant_line_number_to_insert_tests_after -= 1 + relevant_line_number_to_insert_imports_after = tests_dict.get( "relevant_line_number_to_insert_imports_after", None ) @@ -257,6 +267,7 @@ def initial_test_suite_analysis(self): f"Failed to analyze the relevant line number to insert new imports. tests_dict: {tests_dict}" ) + self.test_headers_indentation = test_headers_indentation self.relevant_line_number_to_insert_tests_after = relevant_line_number_to_insert_tests_after self.relevant_line_number_to_insert_imports_after = relevant_line_number_to_insert_imports_after @@ -448,6 +459,15 @@ def validate_test(self, generated_test: dict): with open(self.test_file_path, "w") as test_file: test_file.write(original_content) self.logger.info(f"Skipping a generated test that failed") + + error_message = self.extract_error_message(processed_test, stdout, stderr) + if error_message: + logging.error(f"Error message summary:\n{error_message}") + + self.failed_test_runs.append( + {"code": generated_test, "error_message": error_message} + ) # Append failure details to the list + fail_details = { "status": "FAIL", "reason": "Test failed", @@ -459,15 +479,10 @@ def validate_test(self, generated_test: dict): "source_file": self.source_code, "original_test_file": original_content, "processed_test_file": processed_test, + "error_summary": error_message, } - error_message = self.extract_error_message(fail_details) - if error_message: - logging.error(f"Error message summary:\n{error_message}") - self.failed_test_runs.append( - {"code": generated_test, "error_message": error_message} - ) # Append failure details to the list if "WANDB_API_KEY" in os.environ: fail_details["error_message"] = error_message @@ -618,7 +633,7 @@ def to_dict(self): def to_json(self): return json.dumps(self.to_dict()) - def extract_error_message(self, fail_details): + def extract_error_message(self, processed_test_file, stdout, stderr): """ Extracts the error message from the provided fail details. @@ -627,8 +642,9 @@ def extract_error_message(self, fail_details): Logs errors encountered during the process. Parameters: - fail_details (dict): Dictionary containing test failure details including stderr, stdout, - and processed test file contents. + processed_test_file: The test file under processing + stdout: Output message + stderr: Error message Returns: str: The error summary extracted from the response or an empty string if extraction fails. @@ -638,9 +654,9 @@ def extract_error_message(self, fail_details): response, prompt_token_count, response_token_count, prompt = self.agent_completion.analyze_test_failure( source_file_name=os.path.relpath(self.source_file_path, self.project_root), source_file=self._read_file(self.source_file_path), - processed_test_file=fail_details["processed_test_file"], - stderr=fail_details["stderr"], - stdout=fail_details["stdout"], + processed_test_file=processed_test_file, + stderr=stderr, + stdout=stdout, test_file_name=os.path.relpath(self.test_file_path, self.project_root), ) self.total_input_token_count += prompt_token_count diff --git a/cover_agent/version.txt b/cover_agent/version.txt index 5503126d..a487dd1d 100644 --- a/cover_agent/version.txt +++ b/cover_agent/version.txt @@ -1 +1 @@ -0.3.10 +0.3.12 \ No newline at end of file diff --git a/tests/test_unit_test_db.py b/tests/test_unit_test_db.py index 61304f4c..e25164de 100644 --- a/tests/test_unit_test_db.py +++ b/tests/test_unit_test_db.py @@ -54,6 +54,7 @@ def test_insert_attempt(self, unit_test_db): "source_file": "sample source code", "original_test_file": "sample test code", "processed_test_file": "sample new test code", + "error_summary": "sample error summary", } # Insert the test result into the database @@ -74,6 +75,7 @@ def test_insert_attempt(self, unit_test_db): assert attempt.source_file == "sample source code" assert attempt.original_test_file == "sample test code" assert attempt.processed_test_file == "sample new test code" + assert attempt.error_summary == "sample error summary" def test_dump_to_report(self, unit_test_db, tmp_path): """ @@ -94,6 +96,7 @@ def test_dump_to_report(self, unit_test_db, tmp_path): "source_file": "sample source code", "original_test_file": "sample test code", "processed_test_file": "sample new test code", + "error_summary": "sample error summary", } # Insert the test result into the database diff --git a/tests/test_unit_test_validator.py b/tests/test_unit_test_validator.py index 64c72dea..fd1e7693 100644 --- a/tests/test_unit_test_validator.py +++ b/tests/test_unit_test_validator.py @@ -59,7 +59,9 @@ def test_extract_error_message_exception_handling(self): "stdout": "stdout content", "processed_test_file": "", } - error_message = generator.extract_error_message(fail_details) + error_message = generator.extract_error_message(fail_details["processed_test_file"], + fail_details["stdout"], + fail_details["stderr"]) # Should return an empty string on failure assert error_message == "" @@ -169,7 +171,9 @@ def test_extract_error_message_with_prompt_builder(self): "source_file_name": temp_source_file.name, "source_file": "", } - error_message = generator.extract_error_message(fail_details) + error_message = generator.extract_error_message(fail_details["processed_test_file"], + fail_details["stdout"], + fail_details["stderr"]) assert error_message.strip() == "error_summary: Test failed due to assertion error in test_example" mock_agent_completion_call_args = mock_agent_completion.analyze_test_failure.call_args[1]