diff --git a/src/migrations/versions/ab24602a9f68_migration.py b/src/migrations/versions/ab24602a9f68_migration.py new file mode 100644 index 000000000..5b11e98f6 --- /dev/null +++ b/src/migrations/versions/ab24602a9f68_migration.py @@ -0,0 +1,47 @@ +"""Migration + +Revision ID: ab24602a9f68 +Revises: b6cca7c6d99c +Create Date: 2024-05-10 21:42:27.125458 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "ab24602a9f68" +down_revision = "b6cca7c6d99c" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "autofix_pr_id_to_run_id", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("provider", sa.String(), nullable=False), + sa.Column("pr_id", sa.BigInteger(), nullable=False), + sa.Column("run_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["run_id"], + ["run_state.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("provider", "pr_id", "run_id"), + ) + with op.batch_alter_table("autofix_pr_id_to_run_id", schema=None) as batch_op: + batch_op.create_index( + "ix_autofix_pr_id_to_run_id_provider_pr_id", ["provider", "pr_id"], unique=False + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("autofix_pr_id_to_run_id", schema=None) as batch_op: + batch_op.drop_index("ix_autofix_pr_id_to_run_id_provider_pr_id") + + op.drop_table("autofix_pr_id_to_run_id") + # ### end Alembic commands ### diff --git a/src/seer/app.py b/src/seer/app.py index 3ac32df6e..dfcc2abee 100644 --- a/src/seer/app.py +++ b/src/seer/app.py @@ -7,6 +7,7 @@ from celery_app.config import CeleryQueues from seer.automation.autofix.models import ( AutofixEndpointResponse, + AutofixPrIdRequest, AutofixRequest, AutofixStateRequest, AutofixStateResponse, @@ -16,6 +17,7 @@ from seer.automation.autofix.tasks import ( check_and_mark_if_timed_out, get_autofix_state, + get_autofix_state_from_pr_id, run_autofix_create_pr, run_autofix_execution, run_autofix_root_cause, @@ -202,6 +204,18 @@ def get_autofix_state_endpoint(data: AutofixStateRequest) -> AutofixStateRespons ) +@json_api("/v1/automation/autofix/state/pr") +def get_autofix_state_from_pr_endpoint(data: AutofixPrIdRequest) -> AutofixStateResponse: + state = get_autofix_state_from_pr_id(data.provider, data.pr_id) + + if state: + cur = state.get() + return AutofixStateResponse( + group_id=cur.request.issue.id, state=cur.model_dump(mode="json") + ) + return AutofixStateResponse(group_id=None, state=None) + + @app.route("/health/live", methods=["GET"]) def health_check(): from seer.inference_models import models_loading_status diff --git a/src/seer/automation/autofix/autofix_context.py b/src/seer/automation/autofix/autofix_context.py index 43a45b3c4..20feaa454 100644 --- a/src/seer/automation/autofix/autofix_context.py +++ b/src/seer/automation/autofix/autofix_context.py @@ -21,6 +21,7 @@ from seer.automation.pipeline import PipelineContext from seer.automation.state import State from seer.automation.utils import get_embedding_model +from seer.db import DbPrIdToAutofixRunIdMapping, Session from seer.rpc import RpcClient @@ -285,9 +286,18 @@ def commit_changes(self, repo_id: int | None = None): pr = repo_client.create_pr_from_branch(branch_ref, pr_title, pr_description) change_state.pull_request = CommittedPullRequestDetails( - pr_number=pr.number, pr_url=pr.html_url + pr_number=pr.number, pr_url=pr.html_url, pr_id=pr.id ) + with Session() as session: + pr_id_mapping = DbPrIdToAutofixRunIdMapping( + provider=repo_info.provider, + pr_id=pr.id, + run_id=state.run_id, + ) + session.add(pr_id_mapping) + session.commit() + def _get_org_slug(self, organization_id: int) -> str | None: slug: str | None = None try: diff --git a/src/seer/automation/autofix/models.py b/src/seer/automation/autofix/models.py index 0334730c3..a276fcb33 100644 --- a/src/seer/automation/autofix/models.py +++ b/src/seer/automation/autofix/models.py @@ -87,6 +87,7 @@ class SuggestedFixRootCauseSelection(BaseModel): class CommittedPullRequestDetails(BaseModel): pr_number: int pr_url: str + pr_id: Optional[int] = None class CodebaseChange(BaseModel): @@ -180,9 +181,14 @@ class AutofixStateRequest(BaseModel): group_id: int +class AutofixPrIdRequest(BaseModel): + provider: str + pr_id: int + + class AutofixStateResponse(BaseModel): - group_id: int - state: dict | None + group_id: Optional[int] + state: Optional[dict] class AutofixCompleteArgs(BaseModel): diff --git a/src/seer/automation/autofix/tasks.py b/src/seer/automation/autofix/tasks.py index d28396f8f..1e8324e46 100644 --- a/src/seer/automation/autofix/tasks.py +++ b/src/seer/automation/autofix/tasks.py @@ -28,11 +28,32 @@ from seer.automation.autofix.steps.root_cause_step import RootCauseStep, RootCauseStepRequest from seer.automation.autofix.utils import get_sentry_client from seer.automation.models import InitializationError -from seer.db import DbRunState, Session +from seer.automation.utils import process_repo_provider +from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session logger = logging.getLogger("autofix") +def get_autofix_state_from_pr_id(provider: str, pr_id: int) -> ContinuationState | None: + with Session() as session: + run_state = ( + session.query(DbRunState) + .join(DbPrIdToAutofixRunIdMapping, DbPrIdToAutofixRunIdMapping.run_id == DbRunState.id) + .filter( + DbPrIdToAutofixRunIdMapping.provider == process_repo_provider(provider), + DbPrIdToAutofixRunIdMapping.pr_id == pr_id, + ) + .order_by(DbRunState.id.desc()) + .first() + ) + if run_state is None: + return None + + continuation = ContinuationState.from_id(run_state.id, AutofixContinuation) + + return continuation + + def get_autofix_state(group_id: int) -> ContinuationState | None: with Session() as session: run_state = ( diff --git a/src/seer/automation/models.py b/src/seer/automation/models.py index e3bf2bd14..350e484c2 100644 --- a/src/seer/automation/models.py +++ b/src/seer/automation/models.py @@ -19,6 +19,8 @@ from pydantic_xml import BaseXmlModel from typing_extensions import TypedDict +from seer.automation.utils import process_repo_provider + class StacktraceFrame(BaseModel): model_config = ConfigDict( @@ -193,9 +195,7 @@ def full_name(self): @field_validator("provider", mode="after") @classmethod def validate_provider(cls, provider: str): - cleaned_provider = provider - if provider.startswith("integrations:"): - cleaned_provider = provider.split(":")[1] + cleaned_provider = process_repo_provider(provider) if cleaned_provider != "github": raise ValueError(f"Provider {cleaned_provider} is not supported.") diff --git a/src/seer/automation/utils.py b/src/seer/automation/utils.py index 305b86842..f7620ca29 100644 --- a/src/seer/automation/utils.py +++ b/src/seer/automation/utils.py @@ -67,3 +67,9 @@ def get_embedding_model(): def make_done_signal(id: str | int) -> str: return f"done:{id}" + + +def process_repo_provider(provider: str) -> str: + if provider.startswith("integrations:"): + return provider.split(":")[1] + return provider diff --git a/src/seer/db.py b/src/seer/db.py index 13299726b..a2efc00c4 100644 --- a/src/seer/db.py +++ b/src/seer/db.py @@ -222,6 +222,18 @@ class DbRunState(Base): value: Mapped[dict] = mapped_column(JSON, nullable=False) +class DbPrIdToAutofixRunIdMapping(Base): + __tablename__ = "autofix_pr_id_to_run_id" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + provider: Mapped[str] = mapped_column(String, nullable=False) + pr_id: Mapped[int] = mapped_column(BigInteger, nullable=False) + run_id: Mapped[int] = mapped_column(ForeignKey(DbRunState.id), nullable=False) + __table_args__ = ( + UniqueConstraint("provider", "pr_id", "run_id"), + Index("ix_autofix_pr_id_to_run_id_provider_pr_id", "provider", "pr_id"), + ) + + class DbGroupingRecord(Base): __tablename__ = "grouping_records" id: Mapped[int] = mapped_column(Integer, primary_key=True) diff --git a/tests/automation/autofix/test_autofix_context.py b/tests/automation/autofix/test_autofix_context.py index 825621860..ac3cd6966 100644 --- a/tests/automation/autofix/test_autofix_context.py +++ b/tests/automation/autofix/test_autofix_context.py @@ -1,20 +1,23 @@ import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from johen import generate from seer.automation.autofix.autofix_context import AutofixContext -from seer.automation.autofix.models import AutofixContinuation, AutofixRequest -from seer.automation.codebase.models import QueryResultDocumentChunk -from seer.automation.models import ( - EventDetails, - ExceptionDetails, - IssueDetails, - SentryEventData, - Stacktrace, - StacktraceFrame, +from seer.automation.autofix.models import ( + AutofixContinuation, + AutofixRequest, + AutofixStatus, + ChangesStep, + CodebaseChange, + CodebaseState, + StepType, ) -from seer.automation.state import LocalMemoryState +from seer.automation.autofix.state import ContinuationState +from seer.automation.codebase.models import QueryResultDocumentChunk, RepositoryInfo +from seer.automation.models import FileChange, IssueDetails, SentryEventData +from seer.automation.state import DbState, LocalMemoryState +from seer.db import DbPrIdToAutofixRunIdMapping, Session class TestAutofixContext(unittest.TestCase): @@ -56,5 +59,94 @@ def test_multi_codebase_query(self): self.assertEqual(result_chunks, sorted_chunks) +class TestAutofixContextPrCommit(unittest.TestCase): + def setUp(self): + error_event = next(generate(SentryEventData)) + self.state = ContinuationState.new( + AutofixContinuation( + request=AutofixRequest( + organization_id=1, + project_id=1, + repos=[], + issue=IssueDetails(id=0, title="", events=[error_event], short_id="ISSUE_1"), + ), + ) + ) + self.autofix_context = AutofixContext( + self.state, MagicMock(), MagicMock(), skip_loading_codebase=True + ) + self.autofix_context._get_org_slug = MagicMock(return_value="slug") + + @patch( + "seer.automation.autofix.autofix_context.CodebaseIndex.get_repo_info_from_db", + return_value=RepositoryInfo( + id=1, + organization=1, + project=1, + provider="github", + external_slug="getsentry/slug", + external_id="1", + ), + ) + @patch("seer.automation.autofix.autofix_context.RepoClient") + def test_commit_changes(self, mock_RepoClient, mock_get_repo_info_from_db): + mock_repo_client = MagicMock() + mock_repo_client.create_branch_from_changes.return_value = "test_branch" + mock_pr = MagicMock(number=1, html_url="http://test.com", id=123) + mock_repo_client.create_pr_from_branch.return_value = mock_pr + + mock_RepoClient.from_repo_info.return_value = mock_repo_client + + with self.state.update() as cur: + cur.codebases = { + 1: CodebaseState( + repo_id=1, + namespace_id=1, + file_changes=[ + FileChange( + path="test.py", + reference_snippet="test", + change_type="edit", + new_snippet="test2", + description="test", + ) + ], + ) + } + cur.steps = [ + ChangesStep( + id="changes", + title="changes_title", + type=StepType.CHANGES, + status=AutofixStatus.PENDING, + index=0, + changes=[ + CodebaseChange( + repo_id=1, + repo_name="test", + title="This is the title", + description="This is the description", + ) + ], + ) + ] + + self.autofix_context.commit_changes() + + mock_repo_client.create_pr_from_branch.assert_called_once_with( + "test_branch", + "🤖 This is the title", + "👋 Hi there! This PR was automatically generated 🤖\n\n\nFixes [ISSUE_1](https://sentry.io/organizations/slug/issues/0/)\n\nThis is the description\n\n### 📣 Instructions for the reviewer which is you, yes **you**:\n- **If these changes were incorrect, please close this PR and comment explaining why.**\n- **If these changes were incomplete, please continue working on this PR then merge it.**\n- **If you are feeling confident in my changes, please merge this PR.**\n\nThis will greatly help us improve the autofix system. Thank you! 🙏\n\nIf there are any questions, please reach out to the [AI/ML Team](https://github.com/orgs/getsentry/teams/machine-learning-ai) on [#proj-autofix](https://sentry.slack.com/archives/C06904P7Z6E)\n\n### 🤓 Stats for the nerds:\nPrompt tokens: **0**\nCompletion tokens: **0**\nTotal tokens: **0**", + ) + + with Session() as session: + pr_mapping = session.query(DbPrIdToAutofixRunIdMapping).filter_by(pr_id=123).first() + self.assertIsNotNone(pr_mapping) + + if pr_mapping: + cur = self.state.get() + self.assertEqual(pr_mapping.run_id, cur.run_id) + + if __name__ == "__main__": unittest.main() diff --git a/tests/automation/autofix/test_autofix_tasks.py b/tests/automation/autofix/test_autofix_tasks.py new file mode 100644 index 000000000..d8baca019 --- /dev/null +++ b/tests/automation/autofix/test_autofix_tasks.py @@ -0,0 +1,33 @@ +import unittest + +from johen import generate + +from seer.automation.autofix.models import AutofixContinuation +from seer.automation.autofix.tasks import get_autofix_state_from_pr_id +from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session + + +class TestGetStateFromPr(unittest.TestCase): + def test_successful_state_mapping(self): + state = next(generate(AutofixContinuation)) + with Session() as session: + session.add(DbRunState(id=1, group_id=1, value=state.model_dump(mode="json"))) + session.flush() + session.add(DbPrIdToAutofixRunIdMapping(provider="test", pr_id=1, run_id=1)) + session.commit() + + retrieved_state = get_autofix_state_from_pr_id("test", 1) + self.assertIsNotNone(retrieved_state) + if retrieved_state is not None: + self.assertEqual(retrieved_state.get(), state) + + def test_no_state_mapping(self): + state = next(generate(AutofixContinuation)) + with Session() as session: + session.add(DbRunState(id=1, group_id=1, value=state.model_dump(mode="json"))) + session.flush() + session.add(DbPrIdToAutofixRunIdMapping(provider="test", pr_id=1, run_id=1)) + session.commit() + + retrieved_state = get_autofix_state_from_pr_id("test", 2) + self.assertIsNone(retrieved_state)