diff --git a/src/seer/automation/autofix/tasks.py b/src/seer/automation/autofix/tasks.py index d66269cac..9e570b541 100644 --- a/src/seer/automation/autofix/tasks.py +++ b/src/seer/automation/autofix/tasks.py @@ -301,8 +301,14 @@ def restart_step_with_user_response( cur_state = state.get() if memory: tool_call_id = memory[-1].tool_call_id + tool_call_function = memory[-1].tool_call_function if tool_call_id: - user_response = Message(role="tool", content=text, tool_call_id=tool_call_id) + user_response = Message( + role="tool", + content=text, + tool_call_id=tool_call_id, + tool_call_function=tool_call_function, + ) if memory[-1].role == "tool": memory[-1] = user_response else: diff --git a/src/seer/automation/autofix/tools.py b/src/seer/automation/autofix/tools.py index 4c5e2364a..01cc5afe9 100644 --- a/src/seer/automation/autofix/tools.py +++ b/src/seer/automation/autofix/tools.py @@ -4,6 +4,7 @@ import textwrap from langfuse.decorators import observe +from pydantic import BaseModel from sentry_sdk.ai.monitoring import ai_track from seer.automation.agent.client import GeminiProvider, LlmClient @@ -43,6 +44,45 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.cleanup() + @observe(name="Semantic File Search") + @ai_track(description="Semantic File Search") + @inject + def semantic_file_search( + self, query: str, repo_name: str | None = None, llm_client: LlmClient = injected + ): + repo_client = self.context.get_repo_client(repo_name=repo_name, type=self.repo_client_type) + repo_name = repo_client.repo_name + valid_file_paths = repo_client.get_valid_file_paths(files_only=True) + + self.context.event_manager.add_log(f"Searching for {query}...") + + class FilePath(BaseModel): + file_path: str + + prompt = textwrap.dedent( + """ + I'm searching for the file in this codebase that contains {query}. Please pick the most relevant file from the following list: + {valid_file_paths} + """ + ).format(query=query, valid_file_paths="\n".join(sorted(valid_file_paths))) + + response = llm_client.generate_structured( + prompt=prompt, + model=GeminiProvider(model_name="gemini-2.0-flash-001"), + response_format=FilePath, + ) + result = response.parsed + file_path = result.file_path if result else None + if file_path is None: + return "Could not figure out which file matches what you were looking for. You'll have to try yourself." + + file_contents = self.context.get_file_contents(file_path, repo_name=repo_name) + + if file_contents is None: + return "Could not figure out which file matches what you were looking for. You'll have to try yourself." + + return f"This file might be what you're looking for: `{file_path}`. Contents:\n\n{file_contents}" + @observe(name="Expand Document") @ai_track(description="Expand Document") def expand_document(self, file_path: str, repo_name: str | None = None): @@ -393,6 +433,24 @@ def get_tools(self): ], required=["pattern"], ), + FunctionTool( + name="semantic_file_search", + fn=self.semantic_file_search, + description="Tries to find the file in the codebase that contains what you're looking for.", + parameters=[ + { + "name": "query", + "type": "string", + "description": "Describe what file you're looking for.", + }, + { + "name": "repo_name", + "type": "string", + "description": "Optional name of the repository to search in if you know it.", + }, + ], + required=["query"], + ), FunctionTool( name="search_google", fn=self.search_google, diff --git a/src/seer/automation/codebase/repo_client.py b/src/seer/automation/codebase/repo_client.py index 809a159ae..d7c09dfee 100644 --- a/src/seer/automation/codebase/repo_client.py +++ b/src/seer/automation/codebase/repo_client.py @@ -352,7 +352,7 @@ def get_file_content( return None, "utf-8" @functools.lru_cache(maxsize=8) - def get_valid_file_paths(self, sha: str | None = None) -> set[str]: + def get_valid_file_paths(self, sha: str | None = None, files_only=False) -> set[str]: if sha is None: sha = self.base_commit_sha @@ -366,6 +366,8 @@ def get_valid_file_paths(self, sha: str | None = None) -> set[str]: valid_file_paths: set[str] = set() for file in tree.tree: + if files_only and "." not in file.path: + continue valid_file_paths.add(file.path) return valid_file_paths diff --git a/tests/automation/autofix/test_autofix_tools.py b/tests/automation/autofix/test_autofix_tools.py index 977d6437b..449683dcc 100644 --- a/tests/automation/autofix/test_autofix_tools.py +++ b/tests/automation/autofix/test_autofix_tools.py @@ -120,3 +120,86 @@ def test_cleanup_not_called_when_tmp_dir_is_none(self): tools.cleanup() mock_cleanup_dir.assert_not_called() + + +class TestSemanticFileSearch: + def test_semantic_file_search_found(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.repo_name = "test_repo" + mock_repo_client.get_valid_file_paths.return_value = [ + "src/file1.py", + "tests/test_file1.py", + "src/subfolder/file2.py", + ] + autofix_tools.context.get_repo_client.return_value = mock_repo_client + autofix_tools.context.get_file_contents.return_value = "test file contents" + + mock_llm_client = MagicMock() + mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py" + + result = autofix_tools.semantic_file_search( + "find the main file", llm_client=mock_llm_client + ) + assert ( + result + == "This file might be what you're looking for: `src/file1.py`. Contents:\n\ntest file contents" + ) + + def test_semantic_file_search_not_found_no_file_path(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.repo_name = "test_repo" + mock_repo_client.get_valid_file_paths.return_value = [ + "src/file1.py", + "tests/test_file1.py", + ] + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + mock_llm_client = MagicMock() + mock_llm_client.generate_structured.return_value.parsed = None + + result = autofix_tools.semantic_file_search( + "find nonexistent file", llm_client=mock_llm_client + ) + assert ( + result + == "Could not figure out which file matches what you were looking for. You'll have to try yourself." + ) + + def test_semantic_file_search_not_found_no_contents(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.repo_name = "test_repo" + mock_repo_client.get_valid_file_paths.return_value = [ + "src/file1.py", + "tests/test_file1.py", + ] + autofix_tools.context.get_repo_client.return_value = mock_repo_client + autofix_tools.context.get_file_contents.return_value = None + + mock_llm_client = MagicMock() + mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py" + + result = autofix_tools.semantic_file_search( + "find file with no contents", llm_client=mock_llm_client + ) + assert ( + result + == "Could not figure out which file matches what you were looking for. You'll have to try yourself." + ) + + def test_semantic_file_search_with_repo_name(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.repo_name = "specific_repo" + mock_repo_client.get_valid_file_paths.return_value = ["src/file1.py"] + autofix_tools.context.get_repo_client.return_value = mock_repo_client + autofix_tools.context.get_file_contents.return_value = "test file contents" + autofix_tools.repo_client_type = RepoClientType.READ + + mock_llm_client = MagicMock() + mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py" + + autofix_tools.semantic_file_search( + "find file", repo_name="specific_repo", llm_client=mock_llm_client + ) + autofix_tools.context.get_repo_client.assert_called_once_with( + repo_name="specific_repo", type=RepoClientType.READ + )