Skip to content

Commit

Permalink
feat(autofix): Add semantic file search tool (#1894)
Browse files Browse the repository at this point in the history
Adds a semantic file search tool. Dumps all the file paths in a prompt
and asks the LLM to choose the best match.
Boost evals a lot: root cause up to 88%, fix up to 37%.
  • Loading branch information
roaga authored Feb 10, 2025
1 parent f5cc8c6 commit 0a9fe97
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/seer/automation/autofix/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,14 @@ def restart_step_with_user_response(
cur_state = state.get()
if memory:
tool_call_id = memory[-1].tool_call_id
tool_call_function = memory[-1].tool_call_function
if tool_call_id:
user_response = Message(role="tool", content=text, tool_call_id=tool_call_id)
user_response = Message(
role="tool",
content=text,
tool_call_id=tool_call_id,
tool_call_function=tool_call_function,
)
if memory[-1].role == "tool":
memory[-1] = user_response
else:
Expand Down
58 changes: 58 additions & 0 deletions src/seer/automation/autofix/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import textwrap

from langfuse.decorators import observe
from pydantic import BaseModel
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.client import GeminiProvider, LlmClient
Expand Down Expand Up @@ -43,6 +44,45 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.cleanup()

@observe(name="Semantic File Search")
@ai_track(description="Semantic File Search")
@inject
def semantic_file_search(
self, query: str, repo_name: str | None = None, llm_client: LlmClient = injected
):
repo_client = self.context.get_repo_client(repo_name=repo_name, type=self.repo_client_type)
repo_name = repo_client.repo_name
valid_file_paths = repo_client.get_valid_file_paths(files_only=True)

self.context.event_manager.add_log(f"Searching for {query}...")

class FilePath(BaseModel):
file_path: str

prompt = textwrap.dedent(
"""
I'm searching for the file in this codebase that contains {query}. Please pick the most relevant file from the following list:
{valid_file_paths}
"""
).format(query=query, valid_file_paths="\n".join(sorted(valid_file_paths)))

response = llm_client.generate_structured(
prompt=prompt,
model=GeminiProvider(model_name="gemini-2.0-flash-001"),
response_format=FilePath,
)
result = response.parsed
file_path = result.file_path if result else None
if file_path is None:
return "Could not figure out which file matches what you were looking for. You'll have to try yourself."

file_contents = self.context.get_file_contents(file_path, repo_name=repo_name)

if file_contents is None:
return "Could not figure out which file matches what you were looking for. You'll have to try yourself."

return f"This file might be what you're looking for: `{file_path}`. Contents:\n\n{file_contents}"

@observe(name="Expand Document")
@ai_track(description="Expand Document")
def expand_document(self, file_path: str, repo_name: str | None = None):
Expand Down Expand Up @@ -393,6 +433,24 @@ def get_tools(self):
],
required=["pattern"],
),
FunctionTool(
name="semantic_file_search",
fn=self.semantic_file_search,
description="Tries to find the file in the codebase that contains what you're looking for.",
parameters=[
{
"name": "query",
"type": "string",
"description": "Describe what file you're looking for.",
},
{
"name": "repo_name",
"type": "string",
"description": "Optional name of the repository to search in if you know it.",
},
],
required=["query"],
),
FunctionTool(
name="search_google",
fn=self.search_google,
Expand Down
4 changes: 3 additions & 1 deletion src/seer/automation/codebase/repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def get_file_content(
return None, "utf-8"

@functools.lru_cache(maxsize=8)
def get_valid_file_paths(self, sha: str | None = None) -> set[str]:
def get_valid_file_paths(self, sha: str | None = None, files_only=False) -> set[str]:
if sha is None:
sha = self.base_commit_sha

Expand All @@ -366,6 +366,8 @@ def get_valid_file_paths(self, sha: str | None = None) -> set[str]:
valid_file_paths: set[str] = set()

for file in tree.tree:
if files_only and "." not in file.path:
continue
valid_file_paths.add(file.path)

return valid_file_paths
Expand Down
83 changes: 83 additions & 0 deletions tests/automation/autofix/test_autofix_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,86 @@ def test_cleanup_not_called_when_tmp_dir_is_none(self):
tools.cleanup()

mock_cleanup_dir.assert_not_called()


class TestSemanticFileSearch:
def test_semantic_file_search_found(self, autofix_tools: BaseTools):
mock_repo_client = MagicMock()
mock_repo_client.repo_name = "test_repo"
mock_repo_client.get_valid_file_paths.return_value = [
"src/file1.py",
"tests/test_file1.py",
"src/subfolder/file2.py",
]
autofix_tools.context.get_repo_client.return_value = mock_repo_client
autofix_tools.context.get_file_contents.return_value = "test file contents"

mock_llm_client = MagicMock()
mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py"

result = autofix_tools.semantic_file_search(
"find the main file", llm_client=mock_llm_client
)
assert (
result
== "This file might be what you're looking for: `src/file1.py`. Contents:\n\ntest file contents"
)

def test_semantic_file_search_not_found_no_file_path(self, autofix_tools: BaseTools):
mock_repo_client = MagicMock()
mock_repo_client.repo_name = "test_repo"
mock_repo_client.get_valid_file_paths.return_value = [
"src/file1.py",
"tests/test_file1.py",
]
autofix_tools.context.get_repo_client.return_value = mock_repo_client

mock_llm_client = MagicMock()
mock_llm_client.generate_structured.return_value.parsed = None

result = autofix_tools.semantic_file_search(
"find nonexistent file", llm_client=mock_llm_client
)
assert (
result
== "Could not figure out which file matches what you were looking for. You'll have to try yourself."
)

def test_semantic_file_search_not_found_no_contents(self, autofix_tools: BaseTools):
mock_repo_client = MagicMock()
mock_repo_client.repo_name = "test_repo"
mock_repo_client.get_valid_file_paths.return_value = [
"src/file1.py",
"tests/test_file1.py",
]
autofix_tools.context.get_repo_client.return_value = mock_repo_client
autofix_tools.context.get_file_contents.return_value = None

mock_llm_client = MagicMock()
mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py"

result = autofix_tools.semantic_file_search(
"find file with no contents", llm_client=mock_llm_client
)
assert (
result
== "Could not figure out which file matches what you were looking for. You'll have to try yourself."
)

def test_semantic_file_search_with_repo_name(self, autofix_tools: BaseTools):
mock_repo_client = MagicMock()
mock_repo_client.repo_name = "specific_repo"
mock_repo_client.get_valid_file_paths.return_value = ["src/file1.py"]
autofix_tools.context.get_repo_client.return_value = mock_repo_client
autofix_tools.context.get_file_contents.return_value = "test file contents"
autofix_tools.repo_client_type = RepoClientType.READ

mock_llm_client = MagicMock()
mock_llm_client.generate_structured.return_value.parsed.file_path = "src/file1.py"

autofix_tools.semantic_file_search(
"find file", repo_name="specific_repo", llm_client=mock_llm_client
)
autofix_tools.context.get_repo_client.assert_called_once_with(
repo_name="specific_repo", type=RepoClientType.READ
)

0 comments on commit 0a9fe97

Please sign in to comment.