Skip to content

Commit 6e817d2

Browse files
Merge pull request #32 from step-security/fix/sl/pdpr-retry
Handle pdpr resource bulk updates
2 parents deeac71 + 473d14f commit 6e817d2

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

internal/stepsecurity-api/client.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ type APIClient struct {
6969
Customer string
7070
}
7171

72+
type HTTPRequestOpts func(req *http.Request)
73+
74+
func WithHttpHeaders(headers map[string]string) HTTPRequestOpts {
75+
return func(req *http.Request) {
76+
for key, val := range headers {
77+
req.Header.Set(key, val)
78+
}
79+
}
80+
}
81+
7282
func NewClient(baseURL, apiKey, customer string) (Client, error) {
7383
return &APIClient{
7484
HTTPClient: &http.Client{},
@@ -78,11 +88,15 @@ func NewClient(baseURL, apiKey, customer string) (Client, error) {
7888
}, nil
7989
}
8090

81-
func (c *APIClient) do(req *http.Request) ([]byte, error) {
91+
func (c *APIClient) do(req *http.Request, opts ...HTTPRequestOpts) ([]byte, error) {
8292
if req == nil {
8393
return nil, nil
8494
}
8595

96+
for _, opt := range opts {
97+
opt(req)
98+
}
99+
86100
req.Header.Set("Authorization", "Bearer "+c.APIKey)
87101
res, err := c.HTTPClient.Do(req)
88102
if err != nil {
@@ -112,7 +126,7 @@ func (c *APIClient) get(ctx context.Context, URI string) ([]byte, error) {
112126
return c.do(req)
113127
}
114128

115-
func (c *APIClient) update(ctx context.Context, URI string, payload any, method string) ([]byte, error) {
129+
func (c *APIClient) update(ctx context.Context, URI string, payload any, method string, opts ...HTTPRequestOpts) ([]byte, error) {
116130
reqBody, err := json.Marshal(payload)
117131
if err != nil {
118132
return nil, fmt.Errorf("failed to marshal config: %w", err)
@@ -122,11 +136,11 @@ func (c *APIClient) update(ctx context.Context, URI string, payload any, method
122136
return nil, fmt.Errorf("failed to create request: %w", err)
123137
}
124138
httpReq.Header.Set("Content-Type", "application/json")
125-
return c.do(httpReq)
139+
return c.do(httpReq, opts...)
126140
}
127141

128-
func (c *APIClient) post(ctx context.Context, URI string, payload any) ([]byte, error) {
129-
return c.update(ctx, URI, payload, "POST")
142+
func (c *APIClient) post(ctx context.Context, URI string, payload any, opts ...HTTPRequestOpts) ([]byte, error) {
143+
return c.update(ctx, URI, payload, "POST", opts...)
130144
}
131145

132146
func (c *APIClient) put(ctx context.Context, URI string, payload any) ([]byte, error) {

internal/stepsecurity-api/gh-policy-driven-prs.go

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"strings"
8+
"time"
79

10+
"github.com/hashicorp/go-uuid"
811
"github.com/hashicorp/terraform-plugin-log/tflog"
912
)
1013

@@ -208,10 +211,68 @@ func (c *APIClient) CreatePolicyDrivenPRPolicy(ctx context.Context, createReques
208211

209212
func (c *APIClient) updateConfigForRepo(ctx context.Context, owner string, repo string, config policyDrivenPRConfigOptions) error {
210213
URI := fmt.Sprintf("%s/v1/github/%s/%s/policy-driven-pr/configs", c.BaseURL, owner, repo)
211-
if _, err := c.post(ctx, URI, config); err != nil {
214+
215+
uuid, err := uuid.GenerateUUID()
216+
if err != nil {
217+
return fmt.Errorf("error getting async event id: %w", err)
218+
}
219+
httpHeaders := map[string]string{
220+
"x-async-event-id": uuid,
221+
}
222+
223+
// First attempt
224+
_, err = c.post(ctx, URI, config, WithHttpHeaders(httpHeaders))
225+
if err == nil {
226+
return nil
227+
}
228+
229+
// If it's not a 503, fail immediately
230+
if !strings.Contains(err.Error(), "status: 503") {
212231
return fmt.Errorf("failed to update config for repo: %w", err)
213232
}
214-
return nil
233+
234+
// when status = 503 retry same request until it is completed or retry count is exhausted
235+
timeoutTimer := time.NewTimer(3 * time.Minute)
236+
periodicTicker := time.NewTicker(10 * time.Second) // poll for every 10 seconds
237+
defer func() {
238+
timeoutTimer.Stop()
239+
periodicTicker.Stop()
240+
}()
241+
242+
type retryResp struct {
243+
Status int `json:"status"`
244+
State string `json:"state"` // in_progress, completed
245+
Data any `json:"data"`
246+
}
247+
248+
for {
249+
select {
250+
case <-ctx.Done():
251+
return fmt.Errorf("context cancelled while retrying update config for repo: %w", ctx.Err())
252+
case <-timeoutTimer.C:
253+
return fmt.Errorf("timeout exceeded while updating config for repo")
254+
case <-periodicTicker.C:
255+
response, err1 := c.post(ctx, URI, config, WithHttpHeaders(httpHeaders))
256+
if err1 != nil {
257+
return fmt.Errorf("failed to update config for repo: %w", err)
258+
}
259+
260+
var resp retryResp
261+
err2 := json.Unmarshal(response, &resp)
262+
if err2 != nil {
263+
return fmt.Errorf("failed to update config for repo: %w", err)
264+
}
265+
266+
if resp.State == "completed" {
267+
// check if status code is not 200 and return original error
268+
if resp.Status != 200 {
269+
return fmt.Errorf("failed to update config for repo: %w", err)
270+
}
271+
return nil
272+
}
273+
}
274+
}
275+
215276
}
216277

217278
func (c *APIClient) GetPolicyDrivenPRPolicy(ctx context.Context, owner string, repos []string) (*PolicyDrivenPRPolicy, error) {

0 commit comments

Comments
 (0)