Skip to content

Commit 12e1e81

Browse files
authored
fix getting pull requests that come from forks (#563)
* fix getting pull requests that come from forks * fix support for PRs from forks and reduce total requests * format * refactor pull_requests and logs --------- Co-authored-by: alanpatel <alanpatel@palantir.com>
1 parent ba3bedb commit 12e1e81

25 files changed

Lines changed: 5379 additions & 31 deletions

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ require (
4545
github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 // indirect
4646
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f // indirect
4747
github.com/spf13/pflag v1.0.5 // indirect
48+
github.com/stretchr/objx v0.5.2 // indirect
4849
golang.org/x/net v0.31.0 // indirect
4950
golang.org/x/oauth2 v0.24.0 // indirect
5051
golang.org/x/sys v0.27.0 // indirect

pull/github_context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ func (ghc *GithubContext) Labels(ctx context.Context) ([]string, error) {
282282
func (ghc *GithubContext) IsTargeted(ctx context.Context) (bool, error) {
283283
ref := fmt.Sprintf("refs/heads/%s", ghc.pr.GetHead().GetRef())
284284

285-
prs, err := ListOpenPullRequestsForRef(ctx, ghc.client, ghc.owner, ghc.repo, ref)
285+
prs, err := GetAllOpenPullRequestsForRef(ctx, ghc.client.PullRequests, ghc.owner, ghc.repo, ref)
286286
if err != nil {
287287
return false, errors.Wrap(err, "failed to determine targeted status")
288288
}

pull/pull_requests.go

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018 Palantir Technologies, Inc.
1+
// Copyright 2024 Palantir Technologies, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -23,57 +23,121 @@ import (
2323
"github.com/rs/zerolog"
2424
)
2525

26-
// ListOpenPullRequestsForSHA returns all pull requests where the HEAD of the source branch
27-
// in the pull request matches the given SHA.
28-
func ListOpenPullRequestsForSHA(ctx context.Context, client *github.Client, owner, repoName, SHA string) ([]*github.PullRequest, error) {
29-
prs, _, err := client.PullRequests.ListPullRequestsWithCommit(ctx, owner, repoName, SHA, &github.ListOptions{
30-
// In practice, there should be at most 1-3 PRs for a given commit. In
31-
// exceptional cases, if there are more than 100 PRs, we'll only
32-
// consider the first 100 to avoid paging.
33-
PerPage: 100,
34-
})
35-
if err != nil {
36-
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
37-
}
26+
// GitHubPullRequestClient is an interface that wraps the methods used from the github.Client.
27+
type GitHubPullRequestClient interface {
28+
ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error)
29+
List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error)
30+
}
3831

32+
// getOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch
33+
// matches the given SHA.
34+
func getOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
35+
logger := zerolog.Ctx(ctx)
3936
var results []*github.PullRequest
40-
for _, pr := range prs {
41-
if pr.GetState() == "open" && pr.GetHead().GetSHA() == SHA {
42-
results = append(results, pr)
37+
opts := &github.ListOptions{PerPage: 100}
38+
39+
for {
40+
prs, resp, err := client.ListPullRequestsWithCommit(ctx, owner, repo, sha, opts)
41+
if err != nil {
42+
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
43+
}
44+
45+
for _, pr := range prs {
46+
if pr.GetState() == "open" && pr.GetHead().GetSHA() == sha {
47+
logger.Debug().Msgf("found open pull request with sha %s", pr.GetHead().GetSHA())
48+
results = append(results, pr)
49+
}
50+
}
51+
52+
if resp.NextPage == 0 {
53+
break
4354
}
55+
opts.Page = resp.NextPage
4456
}
57+
4558
return results, nil
4659
}
4760

48-
func ListOpenPullRequestsForRef(ctx context.Context, client *github.Client, owner, repoName, ref string) ([]*github.PullRequest, error) {
61+
// ListAllOpenPullRequestsFilteredBySHA returns all open pull requests where the HEAD of the source branch
62+
// matches the given SHA by fetching all open PRs and filtering.
63+
func ListAllOpenPullRequestsFilteredBySHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
64+
logger := zerolog.Ctx(ctx)
4965
var results []*github.PullRequest
66+
opts := &github.PullRequestListOptions{
67+
State: "open",
68+
ListOptions: github.ListOptions{PerPage: 100},
69+
}
70+
71+
for {
72+
prs, resp, err := client.List(ctx, owner, repo, opts)
73+
if err != nil {
74+
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
75+
}
76+
77+
for _, pr := range prs {
78+
if pr.Head.GetSHA() == sha {
79+
logger.Debug().Msgf("found open pull request with sha %s", pr.Head.GetSHA())
80+
results = append(results, pr)
81+
}
82+
}
83+
84+
if resp.NextPage == 0 {
85+
break
86+
}
87+
opts.Page = resp.NextPage
88+
}
89+
90+
return results, nil
91+
}
92+
93+
// GetAllPossibleOpenPullRequestsForSHA attempts to find all open pull requests
94+
// associated with the given SHA using multiple methods in case we are dealing with a fork
95+
func GetAllPossibleOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
5096
logger := zerolog.Ctx(ctx)
5197

52-
ref = strings.TrimPrefix(ref, "refs/heads/")
98+
prs, err := getOpenPullRequestsForSHA(ctx, client, owner, repo, sha)
99+
if err != nil {
100+
return nil, errors.Wrap(err, "failed to get open pull requests matching the SHA")
101+
}
102+
103+
if len(prs) == 0 {
104+
logger.Debug().Msg("no pull requests found via commit association , searching all pull requests by SHA")
105+
prs, err = ListAllOpenPullRequestsFilteredBySHA(ctx, client, owner, repo, sha)
106+
if err != nil {
107+
return nil, errors.Wrap(err, "failed to list open pull requests matching the SHA")
108+
}
109+
}
110+
111+
return prs, nil
112+
}
53113

114+
// GetAllOpenPullRequestsForRef returns all open pull requests for a given base branch reference.
115+
func GetAllOpenPullRequestsForRef(ctx context.Context, client GitHubPullRequestClient, owner, repo, ref string) ([]*github.PullRequest, error) {
116+
logger := zerolog.Ctx(ctx)
117+
ref = strings.TrimPrefix(ref, "refs/heads/")
54118
opts := &github.PullRequestListOptions{
55-
State: "open",
56-
Base: ref, // Filter by base branch name
57-
ListOptions: github.ListOptions{
58-
PerPage: 100,
59-
},
119+
State: "open",
120+
Base: ref,
121+
ListOptions: github.ListOptions{PerPage: 100},
60122
}
61123

124+
var results []*github.PullRequest
62125
for {
63-
prs, resp, err := client.PullRequests.List(ctx, owner, repoName, opts)
126+
prs, resp, err := client.List(ctx, owner, repo, opts)
64127
if err != nil {
65-
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
128+
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
66129
}
130+
67131
for _, pr := range prs {
68132
logger.Debug().Msgf("found open pull request with base ref %s", pr.GetBase().GetRef())
69133
results = append(results, pr)
70134
}
135+
71136
if resp.NextPage == 0 {
72137
break
73138
}
74139
opts.Page = resp.NextPage
75140
}
76141

77142
return results, nil
78-
79143
}

pull/pull_requests_test.go

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// Copyright 2024 Palantir Technologies, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// pull_test.go
16+
17+
package pull
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/google/go-github/v66/github"
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/mock"
26+
)
27+
28+
type MockGitHubClient struct {
29+
mock.Mock
30+
}
31+
32+
func (m *MockGitHubClient) ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error) {
33+
args := m.Called(ctx, owner, repo, sha, opts)
34+
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
35+
}
36+
37+
func (m *MockGitHubClient) List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) {
38+
args := m.Called(ctx, owner, repo, opts)
39+
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
40+
}
41+
42+
func TestGetOpenPullRequestsForSHA(t *testing.T) {
43+
mockClient := new(MockGitHubClient)
44+
ctx := context.Background()
45+
owner := "owner"
46+
repo := "repo"
47+
sha := "sha"
48+
49+
pr := &github.PullRequest{
50+
State: github.String("open"),
51+
Head: &github.PullRequestBranch{SHA: github.String(sha)},
52+
}
53+
54+
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)
55+
56+
prs, err := getOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
57+
assert.NoError(t, err)
58+
assert.Len(t, prs, 1)
59+
assert.Equal(t, sha, prs[0].GetHead().GetSHA())
60+
61+
mockClient.AssertExpectations(t)
62+
}
63+
64+
func TestListOpenPullRequestsForSHA(t *testing.T) {
65+
mockClient := new(MockGitHubClient)
66+
ctx := context.Background()
67+
owner := "owner"
68+
repo := "repo"
69+
sha := "sha"
70+
71+
pr := &github.PullRequest{
72+
State: github.String("open"),
73+
Head: &github.PullRequestBranch{SHA: github.String(sha)},
74+
}
75+
76+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)
77+
78+
prs, err := ListAllOpenPullRequestsFilteredBySHA(ctx, mockClient, owner, repo, sha)
79+
assert.NoError(t, err)
80+
assert.Len(t, prs, 1)
81+
assert.Equal(t, sha, prs[0].GetHead().GetSHA())
82+
83+
mockClient.AssertExpectations(t)
84+
}
85+
86+
func TestGetAllPossibleOpenPullRequestsForSHA_FirstMethodReturnsResults(t *testing.T) {
87+
mockClient := new(MockGitHubClient)
88+
ctx := context.Background()
89+
owner := "owner"
90+
repo := "repo"
91+
sha := "sha"
92+
93+
pr := &github.PullRequest{
94+
State: github.String("open"),
95+
Head: &github.PullRequestBranch{SHA: github.String(sha)},
96+
}
97+
98+
// Mock the first method to return a valid pull request.
99+
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()
100+
// Mock the second method to not be called.
101+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()
102+
103+
prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
104+
assert.NoError(t, err)
105+
assert.Len(t, prs, 1)
106+
assert.Equal(t, sha, prs[0].GetHead().GetSHA())
107+
108+
mockClient.AssertExpectations(t)
109+
}
110+
111+
func TestGetAllPossibleOpenPullRequestsForSHA_SecondMethodReturnsResults(t *testing.T) {
112+
mockClient := new(MockGitHubClient)
113+
ctx := context.Background()
114+
owner := "owner"
115+
repo := "repo"
116+
sha := "sha"
117+
118+
pr := &github.PullRequest{
119+
State: github.String("open"),
120+
Head: &github.PullRequestBranch{SHA: github.String(sha)},
121+
}
122+
123+
// Mock the first method to return no results.
124+
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
125+
// Mock the second method to return a valid pull request.
126+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()
127+
128+
prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
129+
assert.NoError(t, err)
130+
assert.Len(t, prs, 1)
131+
assert.Equal(t, sha, prs[0].GetHead().GetSHA())
132+
133+
mockClient.AssertExpectations(t)
134+
}
135+
136+
func TestGetAllPossibleOpenPullRequestsForSHA_NoResults(t *testing.T) {
137+
mockClient := new(MockGitHubClient)
138+
ctx := context.Background()
139+
owner := "owner"
140+
repo := "repo"
141+
sha := "sha"
142+
143+
// Mock both methods to return no results.
144+
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
145+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
146+
147+
prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
148+
assert.NoError(t, err)
149+
assert.Len(t, prs, 0)
150+
151+
mockClient.AssertExpectations(t)
152+
}
153+
154+
func TestGetAllPossibleOpenPullRequestsForSHA_Errors(t *testing.T) {
155+
mockClient := new(MockGitHubClient)
156+
ctx := context.Background()
157+
owner := "owner"
158+
repo := "repo"
159+
sha := "sha"
160+
161+
// Mock the first method to return an error.
162+
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{}, assert.AnError).Once()
163+
// Mock the second method to not be called.
164+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()
165+
166+
prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
167+
assert.Error(t, err)
168+
assert.Nil(t, prs)
169+
170+
mockClient.AssertExpectations(t)
171+
}
172+
173+
func TestListOpenPullRequestsForRef(t *testing.T) {
174+
mockClient := new(MockGitHubClient)
175+
ctx := context.Background()
176+
owner := "owner"
177+
repo := "repo"
178+
ref := "refs/heads/main"
179+
180+
pr := &github.PullRequest{
181+
State: github.String("open"),
182+
Base: &github.PullRequestBranch{Ref: github.String("main")},
183+
}
184+
185+
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)
186+
187+
prs, err := GetAllOpenPullRequestsForRef(ctx, mockClient, owner, repo, ref)
188+
assert.NoError(t, err)
189+
assert.Len(t, prs, 1)
190+
assert.Equal(t, "main", prs[0].GetBase().GetRef())
191+
192+
mockClient.AssertExpectations(t)
193+
}

server/handler/check_run.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay
4040
}
4141

4242
repo := event.GetRepo()
43+
owner := repo.GetOwner().GetLogin()
44+
repoName := repo.GetName()
45+
4346
installationID := githubapp.GetInstallationIDFromEvent(&event)
4447

4548
ctx, logger := githubapp.PrepareRepoContext(ctx, installationID, repo)
@@ -56,8 +59,16 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay
5659

5760
prs := event.GetCheckRun().PullRequests
5861
if len(prs) == 0 {
59-
logger.Debug().Msg("Doing nothing since status change event affects no open pull requests")
60-
return nil
62+
logger.Debug().Msg("No pull requests associated with the check run, searching by SHA")
63+
// check runs on fork PRs do not have the PRs attached to the event so we need to filter all PRs by SHA
64+
prs, err = pull.ListAllOpenPullRequestsFilteredBySHA(ctx, client.PullRequests, owner, repoName, event.GetCheckRun().GetHeadSHA())
65+
if err != nil {
66+
return errors.Wrap(err, "failed to determine open pull requests matching the status context change")
67+
}
68+
if len(prs) == 0 {
69+
logger.Debug().Msg("No open pull requests found for the given SHA")
70+
return nil
71+
}
6172
}
6273

6374
for _, pr := range prs {

0 commit comments

Comments
 (0)