feat(autofix): Use an LLM to apply edits to full files in the coding step #1895

Merged (5 commits) on Feb 10, 2025
Changes from 4 commits
11 changes: 11 additions & 0 deletions src/seer/automation/agent/client.py
@@ -108,6 +108,7 @@ def generate_text(
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
predicted_output: str | None = None,
):
message_dicts, tool_dicts = self._prep_message_and_tools(
messages=messages,
@@ -129,6 +130,14 @@ def generate_text(
),
max_tokens=max_tokens or openai.NotGiven(),
timeout=timeout or openai.NotGiven(),
prediction=(
{
"type": "content",
"content": predicted_output,
}
if predicted_output
else openai.NotGiven()
),
)

openai_message = completion.choices[0].message
@@ -1107,6 +1116,7 @@ def generate_text(
max_tokens: int | None = None,
run_name: str | None = None,
timeout: float | None = None,
predicted_output: str | None = None,
) -> LlmGenerateTextResponse:
try:
if run_name:
@@ -1129,6 +1139,7 @@
temperature=temperature or default_temperature,
tools=tools,
timeout=timeout,
predicted_output=predicted_output,
)
elif model.provider_name == LlmProviderType.ANTHROPIC:
model = cast(AnthropicProvider, model)
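For reference, the new predicted_output argument maps onto OpenAI's Predicted Outputs feature via the prediction field added above: the caller supplies the content it expects back so the model can skip regenerating unchanged spans. Below is a minimal standalone sketch of the equivalent raw SDK call; it bypasses the LlmClient wrapper, and the file path and update snippet are hypothetical placeholders, not values from this PR.

from openai import OpenAI  # requires an openai SDK version that supports Predicted Outputs

client = OpenAI()

original_file = open("src/example.py").read()  # hypothetical file being edited
update_snippet = "def add(a, b):\n    return a + b"  # hypothetical edit from the planning step

completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[
        {"role": "system", "content": "Merge the update snippet into the code."},
        {
            "role": "user",
            "content": f"<code>{original_file}</code>\n\n<update>{update_snippet}</update>",
        },
    ],
    # Most of the file should come back unchanged, so the original content
    # serves as the prediction and speeds up decoding.
    prediction={"type": "content", "content": original_file},
)
print(completion.choices[0].message.content)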
252 changes: 167 additions & 85 deletions src/seer/automation/autofix/components/coding/component.py
@@ -1,19 +1,28 @@
import concurrent.futures
import difflib
import logging
import textwrap

import sentry_sdk
from langfuse.decorators import observe
from pydantic import BaseModel
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.agent import AgentConfig, RunConfig
from seer.automation.agent.client import AnthropicProvider, GeminiProvider, LlmClient
from seer.automation.agent.client import (
AnthropicProvider,
GeminiProvider,
LlmClient,
OpenAiProvider,
)
from seer.automation.agent.models import Message
from seer.automation.autofix.autofix_agent import AutofixAgent
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.coding.models import (
CodeChangesPromptXml,
CodingOutput,
CodingRequest,
FileMissingObj,
PlanStepsPromptXml,
PlanTaskPromptXml,
)
from seer.automation.autofix.components.coding.prompts import CodingPrompts
from seer.automation.autofix.components.coding.utils import (
@@ -26,7 +35,6 @@
from seer.automation.models import FileChange
from seer.automation.utils import escape_multi_xml, extract_text_inside_tags
from seer.dependency_injection import inject, injected
from seer.langfuse import append_langfuse_observation_metadata, append_langfuse_trace_tags

logger = logging.getLogger(__name__)

@@ -38,53 +46,6 @@ def _append_file_change(self, repo_external_id: str, file_change: FileChange):
with self.context.state.update() as cur:
cur.codebases[repo_external_id].file_changes.append(file_change)

@observe(name="Incorrect diff fixer")
@ai_track(description="Incorrect diff fixer")
@inject
def _handle_missing_file_changes(
self,
missing_changes_by_file: dict[str, FileMissingObj],
llm_client: LlmClient = injected,
):
for file_path, file_missing_obj in missing_changes_by_file.items():
new_response = llm_client.generate_text(
model=AnthropicProvider.model("claude-3-5-sonnet-v2@20241022"),
prompt=CodingPrompts.format_incorrect_diff_fixer(
file_path,
file_missing_obj.diff_chunks,
file_missing_obj.file_content,
),
temperature=0.0,
run_name="Incorrect Diff Fixer",
)

with self.context.state.update() as cur:
cur.usage += new_response.metadata.usage

if not new_response.message.content:
continue

corrected_diffs = extract_text_inside_tags(
new_response.message.content, "corrected_diffs"
)
new_task = file_missing_obj.task.model_copy(update={"diff": corrected_diffs})

changes, missing_changes = task_to_file_change(new_task, file_missing_obj.file_content)

# If there are any more missing changes, we just ignore at this point
missing_changes_count = len(missing_changes)
append_langfuse_observation_metadata(
{
"missing_changes_count": missing_changes_count,
}
)
if missing_changes_count > 0:
append_langfuse_trace_tags([f"missing_changes_count:{missing_changes_count}"])

repo_client = self.context.get_repo_client(new_task.repo_name)
for change in changes:
self._append_file_change(repo_client.repo_external_id, change)

def _prefill_initial_memory(self) -> list[Message]:
memory: list[Message] = []

@@ -97,6 +58,84 @@ def _prefill_initial_memory(self) -> list[Message]:

return memory

@observe(name="Apply Code")
@ai_track(description="Apply Code")
@inject
def apply_code_suggestion_to_file(
self,
original_content: str | None,
new_content: str,
file_path: str,
llm_client: LlmClient = injected,
) -> str | None:
if original_content is None:
# For new files, create a simple diff showing the entire file as added
lines = new_content.splitlines()
diff = f"--- /dev/null\n+++ b/{file_path}\n"
diff += f"@@ -0,0 +1,{len(lines)} @@\n"
for line in lines:
diff += f"+{line}\n"
return diff

# For existing files, use LLM to merge changes
system_prompt = """You are a coding assistant that helps merge code updates, ensuring every modification is fully integrated."""
prompt = textwrap.dedent(
"""
Merge all changes from the <update> snippet into the <code> below.
- Preserve the code's structure, order, comments, and indentation exactly.
- Output only the updated code, enclosed within <updated_code> and </updated_code> tags.
- Do not include any other text, explanations, placeholders, ellipses, or code fences.

<code>{original_content}</code>

<update>{new_content}</update>

Provide the complete updated code.
"""
).format(original_content=original_content, new_content=new_content)

# use predicted output for faster response
predicted_output = f"<updated_code>{original_content}</updated_code>"
try:
output = llm_client.generate_text(
system_prompt=system_prompt,
messages=[Message(role="user", content=prompt)],
model=OpenAiProvider.model("gpt-4o-mini"),
predicted_output=predicted_output,
)
except Exception as e:
if e.code == 400: # too much content, fall back to a model with a bigger input/output limit
sentry_sdk.capture_message(
f"Failed to apply code suggestion to file with gpt-4o-mini, falling back to o3-mini. Error message: {str(e)}"
)
try:
output = llm_client.generate_text(
system_prompt=system_prompt,
messages=[Message(role="user", content=prompt)],
model=OpenAiProvider.model("o3-mini"),
)
except Exception as e2:
sentry_sdk.capture_exception(e2)
return None
else:
raise e

text = output.message.content
updated_content = extract_text_inside_tags(text, "updated_code")

# Generate unified diff between original_content and updated_content
original_lines = original_content.splitlines()
updated_lines = updated_content.splitlines()
diff = "\n".join(
difflib.unified_diff(
original_lines,
updated_lines,
fromfile=f"a/{file_path}",
tofile=f"b/{file_path}",
)
)
return diff

@observe(name="Is Obvious")
@ai_track(description="Is Obvious")
@inject
@@ -153,6 +192,10 @@ class NeedToSearchCodebaseOutput(BaseModel):

return False

@inject
def _get_llm_client(self, llm_client: LlmClient = injected) -> LlmClient:
return llm_client

@observe(name="Code")
@ai_track(description="Code")
def invoke(self, request: CodingRequest) -> CodingOutput | None:
@@ -197,16 +240,15 @@ def invoke(self, request: CodingRequest) -> CodingOutput | None:
if not response:
return None

plan_steps_content = extract_text_inside_tags(response, "plan_steps")

coding_output = PlanStepsPromptXml.from_xml(
f"<plan_steps>{escape_multi_xml(plan_steps_content, ['diff', 'description', 'commit_message'])}</plan_steps>"
code_changes_content = extract_text_inside_tags(response, "code_changes")
code_changes_output = CodeChangesPromptXml.from_xml(
f"<code_changes>{escape_multi_xml(code_changes_content, ['code', 'commit_message'])}</code_changes>"
).to_model()

# We only do this once, if it still errors, we just let it go
missing_files_errors = []
file_exist_errors = []
for task in coding_output.tasks:
for task in code_changes_output.tasks:
repo_client = self.context.get_repo_client(task.repo_name)
file_content, _ = repo_client.get_file_content(task.file_path, autocorrect=True)
if task.type == "file_change" and not file_content:
@@ -229,48 +271,88 @@ def invoke(self, request: CodingRequest) -> CodingOutput | None:
),
)

if new_response and "<plan_steps>" in new_response:
coding_output = PlanStepsPromptXml.from_xml(
f"<plan_steps>{escape_multi_xml(extract_text_inside_tags(new_response, 'plan_steps'), ['diff', 'description', 'commit_message'])}</plan_steps>"
if new_response and "<code_changes>" in new_response:
code_changes_output = CodeChangesPromptXml.from_xml(
f"<code_changes>{escape_multi_xml(extract_text_inside_tags(new_response, 'code_changes'), ['code', 'commit_message'])}</code_changes>"
).to_model()

missing_changes_by_file: dict[str, FileMissingObj] = dict()
self.context.event_manager.add_log("Applying changes...")
tasks_with_diffs: list[PlanTaskPromptXml] = []

for task in coding_output.tasks:
# Resolve LlmClient once in the main thread
resolved_llm_client = self._get_llm_client()

def process_task(task, llm_client):
repo_client = self.context.get_repo_client(task.repo_name)
if task.type == "file_change":
file_content, _ = repo_client.get_file_content(task.file_path, autocorrect=True)

if not file_content:
logger.warning(f"Failed to get content for {task.file_path}")
continue

changes, missing_changes = task_to_file_change(task, file_content)

for change in changes:
self._append_file_change(repo_client.repo_external_id, change)
return None
diff = self.apply_code_suggestion_to_file(
file_content, task.code, task.file_path, llm_client=llm_client
)
if not diff:
return None
task_with_diff = PlanTaskPromptXml(
file_path=task.file_path,
repo_name=task.repo_name,
type="file_change",
diff=diff,
commit_message=task.commit_message,
description=f"Change file {task.file_path}",
)
changes, _ = task_to_file_change(task_with_diff, file_content)
updates = [(repo_client.repo_external_id, change) for change in changes]
return task_with_diff, updates

if missing_changes:
missing_changes_by_file[task.file_path] = FileMissingObj(
file_path=task.file_path,
file_content=file_content,
diff_chunks=missing_changes,
task=task,
)
elif task.type == "file_delete":
self._append_file_change(
repo_client.repo_external_id,
task_to_file_delete(task),
task_with_diff = PlanTaskPromptXml(
file_path=task.file_path,
repo_name=task.repo_name,
type="file_delete",
commit_message=task.commit_message,
diff="",
description=f"Delete file {task.file_path}",
)
update = (repo_client.repo_external_id, task_to_file_delete(task_with_diff))
return task_with_diff, [update]

elif task.type == "file_create":
self._append_file_change(
repo_client.repo_external_id,
task_to_file_create(task),
diff = self.apply_code_suggestion_to_file(
None, task.code, task.file_path, llm_client=llm_client
)
if not diff:
return None
task_with_diff = PlanTaskPromptXml(
file_path=task.file_path,
repo_name=task.repo_name,
type="file_create",
diff=diff,
commit_message=task.commit_message,
description=f"Create file {task.file_path}",
)
update = (repo_client.repo_external_id, task_to_file_create(task_with_diff))
return task_with_diff, [update]

else:
logger.warning(f"Unsupported task type: {task.type}")
return None

if missing_changes_by_file:
self._handle_missing_file_changes(missing_changes_by_file)

return coding_output
# apply change tasks in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(process_task, task, resolved_llm_client)
for task in code_changes_output.tasks
]
for future in concurrent.futures.as_completed(futures):
result = future.result()
if result:
task_with_diff, updates = result
tasks_with_diffs.append(task_with_diff)
for repo_external_id, file_change in updates:
self._append_file_change(repo_external_id, file_change)
else:
sentry_sdk.capture_message("Failed to apply code changes.")

return CodingOutput(tasks=tasks_with_diffs)
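In sum, the coding step now asks the model to return the complete updated file and derives the unified diff locally, rather than trusting model-written diffs. Below is a compressed sketch of that merge-then-diff flow for illustration only: a simple regex stands in for the repo's extract_text_inside_tags helper, and a hard-coded response replaces the real model call.

import difflib
import re


def merge_and_diff(original: str, llm_response: str, file_path: str) -> str | None:
    # Pull the full updated file out of the <updated_code> tags.
    match = re.search(r"<updated_code>(.*)</updated_code>", llm_response, re.DOTALL)
    if not match:
        return None
    updated = match.group(1)
    # Diff locally so only real changes reach the FileChange pipeline.
    return "\n".join(
        difflib.unified_diff(
            original.splitlines(),
            updated.splitlines(),
            fromfile=f"a/{file_path}",
            tofile=f"b/{file_path}",
        )
    )


# Hypothetical example: a one-line edit yields a small hunk rather than a full-file rewrite.
original = "def add(a, b):\n    return a - b\n"
response = "<updated_code>def add(a, b):\n    return a + b\n</updated_code>"
print(merge_and_diff(original, response, "math_utils.py"))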