Skip to content

Commit

Permalink
feat(autofix): Endpoint to get the autofix state from a pr id (#696)
Browse files Browse the repository at this point in the history
Introduces the `/v1/automation/autofix/state/pr` endpoint which will return an autofix fix state for a git provider + pr id.

We store a mappings for that as a db table when prs are created. This endpoint will be used for the github webhook flows.
  • Loading branch information
jennmueng authored May 15, 2024
1 parent 3946552 commit 133f8d3
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 18 deletions.
47 changes: 47 additions & 0 deletions src/migrations/versions/ab24602a9f68_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Migration
Revision ID: ab24602a9f68
Revises: b6cca7c6d99c
Create Date: 2024-05-10 21:42:27.125458
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "ab24602a9f68"
down_revision = "b6cca7c6d99c"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"autofix_pr_id_to_run_id",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("provider", sa.String(), nullable=False),
sa.Column("pr_id", sa.BigInteger(), nullable=False),
sa.Column("run_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["run_id"],
["run_state.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("provider", "pr_id", "run_id"),
)
with op.batch_alter_table("autofix_pr_id_to_run_id", schema=None) as batch_op:
batch_op.create_index(
"ix_autofix_pr_id_to_run_id_provider_pr_id", ["provider", "pr_id"], unique=False
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("autofix_pr_id_to_run_id", schema=None) as batch_op:
batch_op.drop_index("ix_autofix_pr_id_to_run_id_provider_pr_id")

op.drop_table("autofix_pr_id_to_run_id")
# ### end Alembic commands ###
14 changes: 14 additions & 0 deletions src/seer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from celery_app.config import CeleryQueues
from seer.automation.autofix.models import (
AutofixEndpointResponse,
AutofixPrIdRequest,
AutofixRequest,
AutofixStateRequest,
AutofixStateResponse,
Expand All @@ -16,6 +17,7 @@
from seer.automation.autofix.tasks import (
check_and_mark_if_timed_out,
get_autofix_state,
get_autofix_state_from_pr_id,
run_autofix_create_pr,
run_autofix_execution,
run_autofix_root_cause,
Expand Down Expand Up @@ -202,6 +204,18 @@ def get_autofix_state_endpoint(data: AutofixStateRequest) -> AutofixStateRespons
)


@json_api("/v1/automation/autofix/state/pr")
def get_autofix_state_from_pr_endpoint(data: AutofixPrIdRequest) -> AutofixStateResponse:
state = get_autofix_state_from_pr_id(data.provider, data.pr_id)

if state:
cur = state.get()
return AutofixStateResponse(
group_id=cur.request.issue.id, state=cur.model_dump(mode="json")
)
return AutofixStateResponse(group_id=None, state=None)


@app.route("/health/live", methods=["GET"])
def health_check():
from seer.inference_models import models_loading_status
Expand Down
12 changes: 11 additions & 1 deletion src/seer/automation/autofix/autofix_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from seer.automation.pipeline import PipelineContext
from seer.automation.state import State
from seer.automation.utils import get_embedding_model
from seer.db import DbPrIdToAutofixRunIdMapping, Session
from seer.rpc import RpcClient


Expand Down Expand Up @@ -285,9 +286,18 @@ def commit_changes(self, repo_id: int | None = None):
pr = repo_client.create_pr_from_branch(branch_ref, pr_title, pr_description)

change_state.pull_request = CommittedPullRequestDetails(
pr_number=pr.number, pr_url=pr.html_url
pr_number=pr.number, pr_url=pr.html_url, pr_id=pr.id
)

with Session() as session:
pr_id_mapping = DbPrIdToAutofixRunIdMapping(
provider=repo_info.provider,
pr_id=pr.id,
run_id=state.run_id,
)
session.add(pr_id_mapping)
session.commit()

def _get_org_slug(self, organization_id: int) -> str | None:
slug: str | None = None
try:
Expand Down
10 changes: 8 additions & 2 deletions src/seer/automation/autofix/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class SuggestedFixRootCauseSelection(BaseModel):
class CommittedPullRequestDetails(BaseModel):
pr_number: int
pr_url: str
pr_id: Optional[int] = None


class CodebaseChange(BaseModel):
Expand Down Expand Up @@ -180,9 +181,14 @@ class AutofixStateRequest(BaseModel):
group_id: int


class AutofixPrIdRequest(BaseModel):
provider: str
pr_id: int


class AutofixStateResponse(BaseModel):
group_id: int
state: dict | None
group_id: Optional[int]
state: Optional[dict]


class AutofixCompleteArgs(BaseModel):
Expand Down
23 changes: 22 additions & 1 deletion src/seer/automation/autofix/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,32 @@
from seer.automation.autofix.steps.root_cause_step import RootCauseStep, RootCauseStepRequest
from seer.automation.autofix.utils import get_sentry_client
from seer.automation.models import InitializationError
from seer.db import DbRunState, Session
from seer.automation.utils import process_repo_provider
from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session

logger = logging.getLogger("autofix")


def get_autofix_state_from_pr_id(provider: str, pr_id: int) -> ContinuationState | None:
with Session() as session:
run_state = (
session.query(DbRunState)
.join(DbPrIdToAutofixRunIdMapping, DbPrIdToAutofixRunIdMapping.run_id == DbRunState.id)
.filter(
DbPrIdToAutofixRunIdMapping.provider == process_repo_provider(provider),
DbPrIdToAutofixRunIdMapping.pr_id == pr_id,
)
.order_by(DbRunState.id.desc())
.first()
)
if run_state is None:
return None

continuation = ContinuationState.from_id(run_state.id, AutofixContinuation)

return continuation


def get_autofix_state(group_id: int) -> ContinuationState | None:
with Session() as session:
run_state = (
Expand Down
6 changes: 3 additions & 3 deletions src/seer/automation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pydantic_xml import BaseXmlModel
from typing_extensions import TypedDict

from seer.automation.utils import process_repo_provider


class StacktraceFrame(BaseModel):
model_config = ConfigDict(
Expand Down Expand Up @@ -193,9 +195,7 @@ def full_name(self):
@field_validator("provider", mode="after")
@classmethod
def validate_provider(cls, provider: str):
cleaned_provider = provider
if provider.startswith("integrations:"):
cleaned_provider = provider.split(":")[1]
cleaned_provider = process_repo_provider(provider)

if cleaned_provider != "github":
raise ValueError(f"Provider {cleaned_provider} is not supported.")
Expand Down
6 changes: 6 additions & 0 deletions src/seer/automation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ def get_embedding_model():

def make_done_signal(id: str | int) -> str:
return f"done:{id}"


def process_repo_provider(provider: str) -> str:
if provider.startswith("integrations:"):
return provider.split(":")[1]
return provider
12 changes: 12 additions & 0 deletions src/seer/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ class DbRunState(Base):
value: Mapped[dict] = mapped_column(JSON, nullable=False)


class DbPrIdToAutofixRunIdMapping(Base):
__tablename__ = "autofix_pr_id_to_run_id"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
provider: Mapped[str] = mapped_column(String, nullable=False)
pr_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
run_id: Mapped[int] = mapped_column(ForeignKey(DbRunState.id), nullable=False)
__table_args__ = (
UniqueConstraint("provider", "pr_id", "run_id"),
Index("ix_autofix_pr_id_to_run_id_provider_pr_id", "provider", "pr_id"),
)


class DbGroupingRecord(Base):
__tablename__ = "grouping_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
Expand Down
114 changes: 103 additions & 11 deletions tests/automation/autofix/test_autofix_context.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import unittest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from johen import generate

from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.models import AutofixContinuation, AutofixRequest
from seer.automation.codebase.models import QueryResultDocumentChunk
from seer.automation.models import (
EventDetails,
ExceptionDetails,
IssueDetails,
SentryEventData,
Stacktrace,
StacktraceFrame,
from seer.automation.autofix.models import (
AutofixContinuation,
AutofixRequest,
AutofixStatus,
ChangesStep,
CodebaseChange,
CodebaseState,
StepType,
)
from seer.automation.state import LocalMemoryState
from seer.automation.autofix.state import ContinuationState
from seer.automation.codebase.models import QueryResultDocumentChunk, RepositoryInfo
from seer.automation.models import FileChange, IssueDetails, SentryEventData
from seer.automation.state import DbState, LocalMemoryState
from seer.db import DbPrIdToAutofixRunIdMapping, Session


class TestAutofixContext(unittest.TestCase):
Expand Down Expand Up @@ -56,5 +59,94 @@ def test_multi_codebase_query(self):
self.assertEqual(result_chunks, sorted_chunks)


class TestAutofixContextPrCommit(unittest.TestCase):
def setUp(self):
error_event = next(generate(SentryEventData))
self.state = ContinuationState.new(
AutofixContinuation(
request=AutofixRequest(
organization_id=1,
project_id=1,
repos=[],
issue=IssueDetails(id=0, title="", events=[error_event], short_id="ISSUE_1"),
),
)
)
self.autofix_context = AutofixContext(
self.state, MagicMock(), MagicMock(), skip_loading_codebase=True
)
self.autofix_context._get_org_slug = MagicMock(return_value="slug")

@patch(
"seer.automation.autofix.autofix_context.CodebaseIndex.get_repo_info_from_db",
return_value=RepositoryInfo(
id=1,
organization=1,
project=1,
provider="github",
external_slug="getsentry/slug",
external_id="1",
),
)
@patch("seer.automation.autofix.autofix_context.RepoClient")
def test_commit_changes(self, mock_RepoClient, mock_get_repo_info_from_db):
mock_repo_client = MagicMock()
mock_repo_client.create_branch_from_changes.return_value = "test_branch"
mock_pr = MagicMock(number=1, html_url="http://test.com", id=123)
mock_repo_client.create_pr_from_branch.return_value = mock_pr

mock_RepoClient.from_repo_info.return_value = mock_repo_client

with self.state.update() as cur:
cur.codebases = {
1: CodebaseState(
repo_id=1,
namespace_id=1,
file_changes=[
FileChange(
path="test.py",
reference_snippet="test",
change_type="edit",
new_snippet="test2",
description="test",
)
],
)
}
cur.steps = [
ChangesStep(
id="changes",
title="changes_title",
type=StepType.CHANGES,
status=AutofixStatus.PENDING,
index=0,
changes=[
CodebaseChange(
repo_id=1,
repo_name="test",
title="This is the title",
description="This is the description",
)
],
)
]

self.autofix_context.commit_changes()

mock_repo_client.create_pr_from_branch.assert_called_once_with(
"test_branch",
"🤖 This is the title",
"👋 Hi there! This PR was automatically generated 🤖\n\n\nFixes [ISSUE_1](https://sentry.io/organizations/slug/issues/0/)\n\nThis is the description\n\n### 📣 Instructions for the reviewer which is you, yes **you**:\n- **If these changes were incorrect, please close this PR and comment explaining why.**\n- **If these changes were incomplete, please continue working on this PR then merge it.**\n- **If you are feeling confident in my changes, please merge this PR.**\n\nThis will greatly help us improve the autofix system. Thank you! 🙏\n\nIf there are any questions, please reach out to the [AI/ML Team](https://github.com/orgs/getsentry/teams/machine-learning-ai) on [#proj-autofix](https://sentry.slack.com/archives/C06904P7Z6E)\n\n### 🤓 Stats for the nerds:\nPrompt tokens: **0**\nCompletion tokens: **0**\nTotal tokens: **0**",
)

with Session() as session:
pr_mapping = session.query(DbPrIdToAutofixRunIdMapping).filter_by(pr_id=123).first()
self.assertIsNotNone(pr_mapping)

if pr_mapping:
cur = self.state.get()
self.assertEqual(pr_mapping.run_id, cur.run_id)


if __name__ == "__main__":
unittest.main()
33 changes: 33 additions & 0 deletions tests/automation/autofix/test_autofix_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

from johen import generate

from seer.automation.autofix.models import AutofixContinuation
from seer.automation.autofix.tasks import get_autofix_state_from_pr_id
from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session


class TestGetStateFromPr(unittest.TestCase):
def test_successful_state_mapping(self):
state = next(generate(AutofixContinuation))
with Session() as session:
session.add(DbRunState(id=1, group_id=1, value=state.model_dump(mode="json")))
session.flush()
session.add(DbPrIdToAutofixRunIdMapping(provider="test", pr_id=1, run_id=1))
session.commit()

retrieved_state = get_autofix_state_from_pr_id("test", 1)
self.assertIsNotNone(retrieved_state)
if retrieved_state is not None:
self.assertEqual(retrieved_state.get(), state)

def test_no_state_mapping(self):
state = next(generate(AutofixContinuation))
with Session() as session:
session.add(DbRunState(id=1, group_id=1, value=state.model_dump(mode="json")))
session.flush()
session.add(DbPrIdToAutofixRunIdMapping(provider="test", pr_id=1, run_id=1))
session.commit()

retrieved_state = get_autofix_state_from_pr_id("test", 2)
self.assertIsNone(retrieved_state)

0 comments on commit 133f8d3

Please sign in to comment.