Skip to content
7 changes: 4 additions & 3 deletions cover_agent/ai_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from cover_agent.custom_logger import CustomLogger
from cover_agent.record_replay_manager import RecordReplayManager
from cover_agent.settings.config_loader import get_settings
from cover_agent.utils import get_original_caller
from cover_agent.utils import get_original_caller, contains_any_substring


def conditional_retry(func):
Expand Down Expand Up @@ -109,7 +109,8 @@ def call_model(self, prompt: dict, stream=True):
}

# Model-specific adjustments
if self.model in ["o1-preview", "o1-mini", "o1", "o3-mini"]:
model_names = ["o1-preview", "o1-mini", "o1", "o3-mini", "o4-mini"]
if contains_any_substring(self.model, model_names):
stream = False # o1 doesn't support streaming
completion_params["temperature"] = 1
completion_params["stream"] = False # o1 doesn't support streaming
Expand Down Expand Up @@ -185,4 +186,4 @@ def call_model(self, prompt: dict, stream=True):
)

# Returns: Response, Prompt token count, and Completion token count
return content, prompt_tokens, completion_tokens
return content, prompt_tokens, completion_tokens
3 changes: 1 addition & 2 deletions cover_agent/coverage_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ def parse_missed_covered_lines_jacoco_csv(self, package_name: str, class_name: s

def extract_package_and_class_java(self):
package_pattern = re.compile(r"^\s*package\s+([\w\.]+)\s*;.*$")
class_pattern = re.compile(r"^\s*(?:public\s+)?(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)))?(?:\s+extends|\s+implements|\s*\{|$)")

class_pattern = re.compile(r"^\s*(?:(?:public|private|protected|static|final|abstract)\s+)*(?:class|interface|record)\s+(\w+)(?:(?:<|\().*?(?:>|\)|$))?(?:\s+extends|\s+implements|\s*\{|$)")

package_name = ""
class_name = ""
Expand Down
3 changes: 3 additions & 0 deletions cover_agent/unit_test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class UnitTestGenerationAttempt(Base):
source_file = Column(Text)
original_test_file = Column(Text)
processed_test_file = Column(Text)
error_summary = Column(Text)


class UnitTestDB:
Expand All @@ -52,6 +53,7 @@ def insert_attempt(self, test_result: dict):
source_file=test_result.get("source_file"),
original_test_file=test_result.get("original_test_file"),
processed_test_file=test_result.get("processed_test_file"),
error_summary=test_result.get("error_summary"),
)
session.add(new_attempt)
session.commit()
Expand Down Expand Up @@ -83,6 +85,7 @@ def get_all_attempts(self):
"source_file": attempt.source_file,
"original_test_file": attempt.original_test_file,
"processed_test_file": attempt.processed_test_file,
"error_summary": attempt.error_summary,
}
for attempt in attempts
]
Expand Down
4 changes: 2 additions & 2 deletions cover_agent/unit_test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cover_agent.custom_logger import CustomLogger
from cover_agent.file_preprocessor import FilePreprocessor
from cover_agent.settings.config_loader import get_settings
from cover_agent.utils import load_yaml
from cover_agent.utils import load_yaml, get_included_files


class UnitTestGenerator:
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
self.code_coverage_report_path = code_coverage_report_path
self.test_command = test_command
self.test_command_dir = test_command_dir
self.included_files = included_files
self.included_files = get_included_files(included_files, project_root, True)
self.coverage_type = coverage_type
self.additional_instructions = additional_instructions
self.language = self.get_code_language(source_file_path)
Expand Down
78 changes: 33 additions & 45 deletions cover_agent/unit_test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cover_agent.runner import Runner
from cover_agent.settings.config_loader import get_settings
from cover_agent.settings.config_schema import CoverageType
from cover_agent.utils import load_yaml
from cover_agent.utils import load_yaml, get_included_files


class UnitTestValidator:
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
self.code_coverage_report_path = code_coverage_report_path
self.test_command = test_command
self.test_command_dir = test_command_dir
self.included_files = self.get_included_files(included_files)
self.included_files = get_included_files(included_files, project_root, True)
self.coverage_type = coverage_type
self.desired_coverage = desired_coverage
self.additional_instructions = additional_instructions
Expand Down Expand Up @@ -225,12 +225,16 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_imports_after = None
counter_attempts = 0
while not relevant_line_number_to_insert_tests_after and counter_attempts < allowed_attempts:
test_file_content_numbered = "\n".join(
f"{i + 1} {line}"
for i, line in enumerate(
self._read_file(self.test_file_path).split("\n")
)
)
response, prompt_token_count, response_token_count, prompt = (
self.agent_completion.analyze_test_insert_line(
language=self.language,
test_file_numbered="\n".join(
f"{i + 1} {line}" for i, line in enumerate(self._read_file(self.test_file_path).split("\n"))
),
test_file_numbered=test_file_content_numbered,
additional_instructions_text=self.additional_instructions,
test_file_name=os.path.relpath(self.test_file_path, self.project_root),
)
Expand All @@ -242,6 +246,12 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_tests_after = tests_dict.get(
"relevant_line_number_to_insert_tests_after", None
)

if relevant_line_number_to_insert_tests_after:
file_len = len(test_file_content_numbered.splitlines())
if relevant_line_number_to_insert_tests_after == file_len:
relevant_line_number_to_insert_tests_after -= 1

relevant_line_number_to_insert_imports_after = tests_dict.get(
"relevant_line_number_to_insert_imports_after", None
)
Expand All @@ -257,6 +267,7 @@ def initial_test_suite_analysis(self):
f"Failed to analyze the relevant line number to insert new imports. tests_dict: {tests_dict}"
)


self.test_headers_indentation = test_headers_indentation
self.relevant_line_number_to_insert_tests_after = relevant_line_number_to_insert_tests_after
self.relevant_line_number_to_insert_imports_after = relevant_line_number_to_insert_imports_after
Expand Down Expand Up @@ -306,34 +317,6 @@ def run_coverage(self):
with open(self.code_coverage_report_path, "r") as f:
self.code_coverage_report = f.read()

@staticmethod
def get_included_files(included_files):
"""
A method to read and concatenate the contents of included files into a single string.

Parameters:
included_files (list): A list of paths to included files.

Returns:
str: A string containing the concatenated contents of the included files, or an empty string if the input list is empty.
"""
if included_files:
included_files_content = []
file_names = []
for file_path in included_files:
try:
with open(file_path, "r") as file:
included_files_content.append(file.read())
file_names.append(file_path)
except IOError as e:
print(f"Error reading file {file_path}: {str(e)}")
out_str = ""
if included_files_content:
for i, content in enumerate(included_files_content):
out_str += f"file_path: `{file_names[i]}`\ncontent:\n```\n{content}\n```\n"

return out_str.strip()
return ""

def validate_test(self, generated_test: dict):
"""
Expand Down Expand Up @@ -448,6 +431,15 @@ def validate_test(self, generated_test: dict):
with open(self.test_file_path, "w") as test_file:
test_file.write(original_content)
self.logger.info(f"Skipping a generated test that failed")

error_message = self.extract_error_message(processed_test, stdout, stderr)
if error_message:
logging.error(f"Error message summary:\n{error_message}")

self.failed_test_runs.append(
{"code": generated_test, "error_message": error_message}
) # Append failure details to the list

fail_details = {
"status": "FAIL",
"reason": "Test failed",
Expand All @@ -459,15 +451,10 @@ def validate_test(self, generated_test: dict):
"source_file": self.source_code,
"original_test_file": original_content,
"processed_test_file": processed_test,
"error_summary": error_message,
}

error_message = self.extract_error_message(fail_details)
if error_message:
logging.error(f"Error message summary:\n{error_message}")

self.failed_test_runs.append(
{"code": generated_test, "error_message": error_message}
) # Append failure details to the list

if "WANDB_API_KEY" in os.environ:
fail_details["error_message"] = error_message
Expand Down Expand Up @@ -618,7 +605,7 @@ def to_dict(self):
def to_json(self):
return json.dumps(self.to_dict())

def extract_error_message(self, fail_details):
def extract_error_message(self, processed_test_file, stdout, stderr):
"""
Extracts the error message from the provided fail details.

Expand All @@ -627,8 +614,9 @@ def extract_error_message(self, fail_details):
Logs errors encountered during the process.

Parameters:
fail_details (dict): Dictionary containing test failure details including stderr, stdout,
and processed test file contents.
processed_test_file: The test file under processing
stdout: Output message
stderr: Error message

Returns:
str: The error summary extracted from the response or an empty string if extraction fails.
Expand All @@ -638,9 +626,9 @@ def extract_error_message(self, fail_details):
response, prompt_token_count, response_token_count, prompt = self.agent_completion.analyze_test_failure(
source_file_name=os.path.relpath(self.source_file_path, self.project_root),
source_file=self._read_file(self.source_file_path),
processed_test_file=fail_details["processed_test_file"],
stderr=fail_details["stderr"],
stdout=fail_details["stdout"],
processed_test_file=processed_test_file,
stderr=stderr,
stdout=stdout,
test_file_name=os.path.relpath(self.test_file_path, self.project_root),
)
self.total_input_token_count += prompt_token_count
Expand Down
13 changes: 13 additions & 0 deletions cover_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,16 @@ def truncate_hash(hash_value: str, hash_display_length: int) -> str:
truncate_hash("abcdef123456", 6) # Returns "abcdef"
"""
return hash_value[:hash_display_length]

def contains_any_substring(main_string, substrings):
"""
Checks if any of the substrings exist in the main string.

Args:
main_string: The string to search within.
substrings: A list or tuple of substrings to search for.

Returns:
True if any substring is found, False otherwise.
"""
return any(sub in main_string for sub in substrings)
2 changes: 1 addition & 1 deletion cover_agent/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.10
0.3.12
3 changes: 3 additions & 0 deletions tests/test_unit_test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_insert_attempt(self, unit_test_db):
"source_file": "sample source code",
"original_test_file": "sample test code",
"processed_test_file": "sample new test code",
"error_summary": "sample error summary",
}

# Insert the test result into the database
Expand All @@ -74,6 +75,7 @@ def test_insert_attempt(self, unit_test_db):
assert attempt.source_file == "sample source code"
assert attempt.original_test_file == "sample test code"
assert attempt.processed_test_file == "sample new test code"
assert attempt.error_summary == "sample error summary"

def test_dump_to_report(self, unit_test_db, tmp_path):
"""
Expand All @@ -94,6 +96,7 @@ def test_dump_to_report(self, unit_test_db, tmp_path):
"source_file": "sample source code",
"original_test_file": "sample test code",
"processed_test_file": "sample new test code",
"error_summary": "sample error summary",
}

# Insert the test result into the database
Expand Down
8 changes: 6 additions & 2 deletions tests/test_unit_test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def test_extract_error_message_exception_handling(self):
"stdout": "stdout content",
"processed_test_file": "",
}
error_message = generator.extract_error_message(fail_details)
error_message = generator.extract_error_message(fail_details["processed_test_file"],
fail_details["stdout"],
fail_details["stderr"])

# Should return an empty string on failure
assert error_message == ""
Expand Down Expand Up @@ -169,7 +171,9 @@ def test_extract_error_message_with_prompt_builder(self):
"source_file_name": temp_source_file.name,
"source_file": "",
}
error_message = generator.extract_error_message(fail_details)
error_message = generator.extract_error_message(fail_details["processed_test_file"],
fail_details["stdout"],
fail_details["stderr"])

assert error_message.strip() == "error_summary: Test failed due to assertion error in test_example"
mock_agent_completion_call_args = mock_agent_completion.analyze_test_failure.call_args[1]
Expand Down
Loading