Skip to content

Commit 133f8d3

Browse files
authored
feat(autofix): Endpoint to get the autofix state from a pr id (#696)
Introduces the `/v1/automation/autofix/state/pr` endpoint, which returns the autofix state for a git provider + PR id. We store that mapping in a DB table when PRs are created. This endpoint will be used for the GitHub webhook flows.
1 parent 3946552 commit 133f8d3

File tree

10 files changed

+259
-18
lines changed

10 files changed

+259
-18
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Migration
2+
3+
Revision ID: ab24602a9f68
4+
Revises: b6cca7c6d99c
5+
Create Date: 2024-05-10 21:42:27.125458
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "ab24602a9f68"
13+
down_revision = "b6cca7c6d99c"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade():
    """Create the ``autofix_pr_id_to_run_id`` table and its lookup index.

    The table has an integer primary key ``id``, a string ``provider``, a
    big-integer ``pr_id`` and a ``run_id`` foreign key to ``run_state.id``.
    A unique constraint covers ``(provider, pr_id, run_id)`` and a
    non-unique index covers ``(provider, pr_id)``.
    """
    table_name = "autofix_pr_id_to_run_id"

    # ### commands auto generated by Alembic - please adjust! ###
    columns = [
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("provider", sa.String(), nullable=False),
        sa.Column("pr_id", sa.BigInteger(), nullable=False),
        sa.Column("run_id", sa.Integer(), nullable=False),
    ]
    constraints = [
        sa.ForeignKeyConstraint(["run_id"], ["run_state.id"]),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("provider", "pr_id", "run_id"),
    ]
    op.create_table(table_name, *columns, *constraints)

    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.create_index(
            "ix_autofix_pr_id_to_run_id_provider_pr_id",
            ["provider", "pr_id"],
            unique=False,
        )

    # ### end Alembic commands ###
39+
40+
41+
def downgrade():
    """Drop the ``(provider, pr_id)`` index, then the mapping table itself."""
    table_name = "autofix_pr_id_to_run_id"

    # ### commands auto generated by Alembic - please adjust! ###
    # The index is removed first, then the table it belongs to.
    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.drop_index("ix_autofix_pr_id_to_run_id_provider_pr_id")

    op.drop_table(table_name)
    # ### end Alembic commands ###

src/seer/app.py

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from celery_app.config import CeleryQueues
88
from seer.automation.autofix.models import (
99
AutofixEndpointResponse,
10+
AutofixPrIdRequest,
1011
AutofixRequest,
1112
AutofixStateRequest,
1213
AutofixStateResponse,
@@ -16,6 +17,7 @@
1617
from seer.automation.autofix.tasks import (
1718
check_and_mark_if_timed_out,
1819
get_autofix_state,
20+
get_autofix_state_from_pr_id,
1921
run_autofix_create_pr,
2022
run_autofix_execution,
2123
run_autofix_root_cause,
@@ -202,6 +204,18 @@ def get_autofix_state_endpoint(data: AutofixStateRequest) -> AutofixStateRespons
202204
)
203205

204206

207+
@json_api("/v1/automation/autofix/state/pr")
def get_autofix_state_from_pr_endpoint(data: AutofixPrIdRequest) -> AutofixStateResponse:
    """Return the autofix state recorded for a provider + pull request id.

    Args:
        data: Request carrying the ``provider`` name and the ``pr_id``.

    Returns:
        A response holding the issue id and the JSON-dumped state when a
        state is found; otherwise a response with both fields set to ``None``.
    """
    continuation = get_autofix_state_from_pr_id(data.provider, data.pr_id)

    # Nothing found for this provider/PR pair: answer with an empty payload.
    if not continuation:
        return AutofixStateResponse(group_id=None, state=None)

    snapshot = continuation.get()
    return AutofixStateResponse(
        group_id=snapshot.request.issue.id,
        state=snapshot.model_dump(mode="json"),
    )
217+
218+
205219
@app.route("/health/live", methods=["GET"])
206220
def health_check():
207221
from seer.inference_models import models_loading_status

src/seer/automation/autofix/autofix_context.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from seer.automation.pipeline import PipelineContext
2222
from seer.automation.state import State
2323
from seer.automation.utils import get_embedding_model
24+
from seer.db import DbPrIdToAutofixRunIdMapping, Session
2425
from seer.rpc import RpcClient
2526

2627

@@ -285,9 +286,18 @@ def commit_changes(self, repo_id: int | None = None):
285286
pr = repo_client.create_pr_from_branch(branch_ref, pr_title, pr_description)
286287

287288
change_state.pull_request = CommittedPullRequestDetails(
288-
pr_number=pr.number, pr_url=pr.html_url
289+
pr_number=pr.number, pr_url=pr.html_url, pr_id=pr.id
289290
)
290291

292+
with Session() as session:
293+
pr_id_mapping = DbPrIdToAutofixRunIdMapping(
294+
provider=repo_info.provider,
295+
pr_id=pr.id,
296+
run_id=state.run_id,
297+
)
298+
session.add(pr_id_mapping)
299+
session.commit()
300+
291301
def _get_org_slug(self, organization_id: int) -> str | None:
292302
slug: str | None = None
293303
try:

src/seer/automation/autofix/models.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class SuggestedFixRootCauseSelection(BaseModel):
8787
class CommittedPullRequestDetails(BaseModel):
    # Details of a pull request that has been created for a change.

    # PR number and URL as reported by the repo client (`pr.number`, `pr.html_url`).
    pr_number: int
    pr_url: str

    # Provider-side PR id (`pr.id`). Optional with a None default —
    # presumably so states stored without it still validate; TODO confirm.
    pr_id: Optional[int] = None
9091

9192

9293
class CodebaseChange(BaseModel):
@@ -180,9 +181,14 @@ class AutofixStateRequest(BaseModel):
180181
group_id: int
181182

182183

184+
class AutofixPrIdRequest(BaseModel):
    # Request body for looking up an autofix state by pull request.

    # Git provider name; the lookup passes it through `process_repo_provider`.
    provider: str

    # Pull request id to look up.
    pr_id: int
187+
188+
183189
class AutofixStateResponse(BaseModel):
    # Response for the autofix state endpoints.
    # Both fields are nullable so that a lookup that finds nothing can be
    # answered with `group_id=None, state=None`.

    group_id: Optional[int]
    state: Optional[dict]
186192

187193

188194
class AutofixCompleteArgs(BaseModel):

src/seer/automation/autofix/tasks.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,32 @@
2828
from seer.automation.autofix.steps.root_cause_step import RootCauseStep, RootCauseStepRequest
2929
from seer.automation.autofix.utils import get_sentry_client
3030
from seer.automation.models import InitializationError
31-
from seer.db import DbRunState, Session
31+
from seer.automation.utils import process_repo_provider
32+
from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session
3233

3334
logger = logging.getLogger("autofix")
3435

3536

37+
def get_autofix_state_from_pr_id(provider: str, pr_id: int) -> ContinuationState | None:
    """Find the autofix continuation state linked to a pull request.

    Args:
        provider: Repo provider name; normalized with ``process_repo_provider``
            before being compared against stored mappings.
        pr_id: Pull request id stored in the PR-to-run mapping table.

    Returns:
        The ``ContinuationState`` for the matching run with the highest id,
        or ``None`` when no run is mapped to this provider/PR pair.
    """
    with Session() as session:
        mapping = DbPrIdToAutofixRunIdMapping
        matching_runs = (
            session.query(DbRunState)
            .join(mapping, mapping.run_id == DbRunState.id)
            .filter(
                mapping.provider == process_repo_provider(provider),
                mapping.pr_id == pr_id,
            )
            # Highest run id first, so `.first()` picks the latest-created row.
            .order_by(DbRunState.id.desc())
        )
        newest_run = matching_runs.first()

        if newest_run is None:
            return None

        return ContinuationState.from_id(newest_run.id, AutofixContinuation)
55+
56+
3657
def get_autofix_state(group_id: int) -> ContinuationState | None:
3758
with Session() as session:
3859
run_state = (

src/seer/automation/models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pydantic_xml import BaseXmlModel
2020
from typing_extensions import TypedDict
2121

22+
from seer.automation.utils import process_repo_provider
23+
2224

2325
class StacktraceFrame(BaseModel):
2426
model_config = ConfigDict(
@@ -193,9 +195,7 @@ def full_name(self):
193195
@field_validator("provider", mode="after")
194196
@classmethod
195197
def validate_provider(cls, provider: str):
196-
cleaned_provider = provider
197-
if provider.startswith("integrations:"):
198-
cleaned_provider = provider.split(":")[1]
198+
cleaned_provider = process_repo_provider(provider)
199199

200200
if cleaned_provider != "github":
201201
raise ValueError(f"Provider {cleaned_provider} is not supported.")

src/seer/automation/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,9 @@ def get_embedding_model():
6767

6868
def make_done_signal(id: str | int) -> str:
6969
return f"done:{id}"
70+
71+
72+
def process_repo_provider(provider: str) -> str:
    """Strip an ``integrations:`` prefix from a repo provider name.

    ``"integrations:github"`` becomes ``"github"``. Only the segment between
    the first and second colon is kept, so ``"integrations:a:b"`` yields
    ``"a"``. Names without the prefix are returned unchanged.
    """
    prefix = "integrations:"
    if not provider.startswith(prefix):
        return provider

    # Take what follows the prefix, up to (not including) the next colon.
    remainder = provider[len(prefix):]
    segment, _, _ = remainder.partition(":")
    return segment

src/seer/db.py

+12
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,18 @@ class DbRunState(Base):
222222
value: Mapped[dict] = mapped_column(JSON, nullable=False)
223223

224224

225+
class DbPrIdToAutofixRunIdMapping(Base):
    # Maps a pull request (provider + PR id) to a run in `DbRunState`.

    __tablename__ = "autofix_pr_id_to_run_id"
    __table_args__ = (
        # The same (provider, pr_id, run_id) triple may be stored only once.
        UniqueConstraint("provider", "pr_id", "run_id"),
        # Non-unique index on the (provider, pr_id) pair.
        Index("ix_autofix_pr_id_to_run_id_provider_pr_id", "provider", "pr_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    provider: Mapped[str] = mapped_column(String, nullable=False)
    # Stored as a big integer.
    pr_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
    run_id: Mapped[int] = mapped_column(ForeignKey(DbRunState.id), nullable=False)
235+
236+
225237
class DbGroupingRecord(Base):
226238
__tablename__ = "grouping_records"
227239
id: Mapped[int] = mapped_column(Integer, primary_key=True)

tests/automation/autofix/test_autofix_context.py

+103-11
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import unittest
2-
from unittest.mock import MagicMock
2+
from unittest.mock import MagicMock, patch
33

44
from johen import generate
55

66
from seer.automation.autofix.autofix_context import AutofixContext
7-
from seer.automation.autofix.models import AutofixContinuation, AutofixRequest
8-
from seer.automation.codebase.models import QueryResultDocumentChunk
9-
from seer.automation.models import (
10-
EventDetails,
11-
ExceptionDetails,
12-
IssueDetails,
13-
SentryEventData,
14-
Stacktrace,
15-
StacktraceFrame,
7+
from seer.automation.autofix.models import (
8+
AutofixContinuation,
9+
AutofixRequest,
10+
AutofixStatus,
11+
ChangesStep,
12+
CodebaseChange,
13+
CodebaseState,
14+
StepType,
1615
)
17-
from seer.automation.state import LocalMemoryState
16+
from seer.automation.autofix.state import ContinuationState
17+
from seer.automation.codebase.models import QueryResultDocumentChunk, RepositoryInfo
18+
from seer.automation.models import FileChange, IssueDetails, SentryEventData
19+
from seer.automation.state import DbState, LocalMemoryState
20+
from seer.db import DbPrIdToAutofixRunIdMapping, Session
1821

1922

2023
class TestAutofixContext(unittest.TestCase):
@@ -56,5 +59,94 @@ def test_multi_codebase_query(self):
5659
self.assertEqual(result_chunks, sorted_chunks)
5760

5861

62+
class TestAutofixContextPrCommit(unittest.TestCase):
    """Checks that committing changes opens a PR and records the PR-to-run mapping."""

    def setUp(self):
        # A continuation state with one generated event and no repos.
        error_event = next(generate(SentryEventData))
        issue = IssueDetails(id=0, title="", events=[error_event], short_id="ISSUE_1")
        request = AutofixRequest(
            organization_id=1,
            project_id=1,
            repos=[],
            issue=issue,
        )
        self.state = ContinuationState.new(AutofixContinuation(request=request))
        self.autofix_context = AutofixContext(
            self.state, MagicMock(), MagicMock(), skip_loading_codebase=True
        )
        # Fixed slug; the expected PR description below contains it.
        self.autofix_context._get_org_slug = MagicMock(return_value="slug")

    @patch(
        "seer.automation.autofix.autofix_context.CodebaseIndex.get_repo_info_from_db",
        return_value=RepositoryInfo(
            id=1,
            organization=1,
            project=1,
            provider="github",
            external_slug="getsentry/slug",
            external_id="1",
        ),
    )
    @patch("seer.automation.autofix.autofix_context.RepoClient")
    def test_commit_changes(self, mock_RepoClient, mock_get_repo_info_from_db):
        # Repo client stub: branch creation succeeds and the PR has id 123.
        fake_pr = MagicMock(number=1, html_url="http://test.com", id=123)
        fake_client = MagicMock()
        fake_client.create_branch_from_changes.return_value = "test_branch"
        fake_client.create_pr_from_branch.return_value = fake_pr
        mock_RepoClient.from_repo_info.return_value = fake_client

        pending_file_change = FileChange(
            path="test.py",
            reference_snippet="test",
            change_type="edit",
            new_snippet="test2",
            description="test",
        )
        codebase_state = CodebaseState(
            repo_id=1,
            namespace_id=1,
            file_changes=[pending_file_change],
        )
        changes_step = ChangesStep(
            id="changes",
            title="changes_title",
            type=StepType.CHANGES,
            status=AutofixStatus.PENDING,
            index=0,
            changes=[
                CodebaseChange(
                    repo_id=1,
                    repo_name="test",
                    title="This is the title",
                    description="This is the description",
                )
            ],
        )

        with self.state.update() as cur:
            cur.codebases = {1: codebase_state}
            cur.steps = [changes_step]

        self.autofix_context.commit_changes()

        expected_title = "🤖 This is the title"
        expected_description = "👋 Hi there! This PR was automatically generated 🤖\n\n\nFixes [ISSUE_1](https://sentry.io/organizations/slug/issues/0/)\n\nThis is the description\n\n### 📣 Instructions for the reviewer which is you, yes **you**:\n- **If these changes were incorrect, please close this PR and comment explaining why.**\n- **If these changes were incomplete, please continue working on this PR then merge it.**\n- **If you are feeling confident in my changes, please merge this PR.**\n\nThis will greatly help us improve the autofix system. Thank you! 🙏\n\nIf there are any questions, please reach out to the [AI/ML Team](https://github.com/orgs/getsentry/teams/machine-learning-ai) on [#proj-autofix](https://sentry.slack.com/archives/C06904P7Z6E)\n\n### 🤓 Stats for the nerds:\nPrompt tokens: **0**\nCompletion tokens: **0**\nTotal tokens: **0**"
        fake_client.create_pr_from_branch.assert_called_once_with(
            "test_branch",
            expected_title,
            expected_description,
        )

        # The commit must also have stored a mapping row for PR id 123.
        with Session() as session:
            pr_mapping = session.query(DbPrIdToAutofixRunIdMapping).filter_by(pr_id=123).first()
            self.assertIsNotNone(pr_mapping)

            if pr_mapping:
                current = self.state.get()
                self.assertEqual(pr_mapping.run_id, current.run_id)
149+
150+
59151
if __name__ == "__main__":
60152
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
3+
from johen import generate
4+
5+
from seer.automation.autofix.models import AutofixContinuation
6+
from seer.automation.autofix.tasks import get_autofix_state_from_pr_id
7+
from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session
8+
9+
10+
class TestGetStateFromPr(unittest.TestCase):
    """Tests for looking up an autofix state through the PR-id mapping."""

    def _seed_run_with_pr_mapping(self):
        """Store run 1 and map ("test", pr_id=1) to it; return the generated state."""
        generated = next(generate(AutofixContinuation))
        with Session() as session:
            session.add(DbRunState(id=1, group_id=1, value=generated.model_dump(mode="json")))
            # Flush so the run row exists before the mapping that references it.
            session.flush()
            session.add(DbPrIdToAutofixRunIdMapping(provider="test", pr_id=1, run_id=1))
            session.commit()
        return generated

    def test_successful_state_mapping(self):
        expected_state = self._seed_run_with_pr_mapping()

        retrieved_state = get_autofix_state_from_pr_id("test", 1)

        self.assertIsNotNone(retrieved_state)
        if retrieved_state is not None:
            self.assertEqual(retrieved_state.get(), expected_state)

    def test_no_state_mapping(self):
        self._seed_run_with_pr_mapping()

        # PR id 2 was never mapped, so the lookup must come back empty.
        retrieved_state = get_autofix_state_from_pr_id("test", 2)

        self.assertIsNone(retrieved_state)

0 commit comments

Comments
 (0)