Skip to content

Commit a8631f1

Browse files
committed
feat: add retry with exponential backoff for network requests
Fixes #5
1 parent 00df31a commit a8631f1

File tree

5 files changed

+255
-42
lines changed

5 files changed

+255
-42
lines changed

internal/github/client.go

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"strings"
77

88
gogithub "github.com/google/go-github/v68/github"
9+
10+
"github.com/omarkohl/jip/internal/retry"
911
)
1012

1113
// Service defines the GitHub operations needed by the send pipeline.
@@ -75,12 +77,17 @@ type UpdatePROpts struct {
7577

7678
// CreatePR creates a new pull request and returns its info.
7779
func (c *Client) CreatePR(head, base, title, body string, draft bool) (*PRInfo, error) {
78-
pr, _, err := c.gh.PullRequests.Create(context.Background(), c.owner, c.repo, &gogithub.NewPullRequest{
79-
Title: &title,
80-
Head: &head,
81-
Base: &base,
82-
Body: &body,
83-
Draft: &draft,
80+
var pr *gogithub.PullRequest
81+
err := retry.Do(func() error {
82+
var apiErr error
83+
pr, _, apiErr = c.gh.PullRequests.Create(context.Background(), c.owner, c.repo, &gogithub.NewPullRequest{
84+
Title: &title,
85+
Head: &head,
86+
Base: &base,
87+
Body: &body,
88+
Draft: &draft,
89+
})
90+
return apiErr
8491
})
8592
if err != nil {
8693
return nil, fmt.Errorf("creating PR: %w", err)
@@ -109,7 +116,10 @@ func (c *Client) UpdatePR(number int, opts UpdatePROpts) error {
109116
if opts.Base != nil {
110117
update.Base = &gogithub.PullRequestBranch{Ref: opts.Base}
111118
}
112-
_, _, err := c.gh.PullRequests.Edit(context.Background(), c.owner, c.repo, number, update)
119+
err := retry.Do(func() error {
120+
_, _, apiErr := c.gh.PullRequests.Edit(context.Background(), c.owner, c.repo, number, update)
121+
return apiErr
122+
})
113123
if err != nil {
114124
return fmt.Errorf("updating PR #%d: %w", number, err)
115125
}
@@ -118,8 +128,11 @@ func (c *Client) UpdatePR(number int, opts UpdatePROpts) error {
118128

119129
// CommentOnPR posts a comment on a pull request.
120130
func (c *Client) CommentOnPR(number int, body string) error {
121-
_, _, err := c.gh.Issues.CreateComment(context.Background(), c.owner, c.repo, number, &gogithub.IssueComment{
122-
Body: &body,
131+
err := retry.Do(func() error {
132+
_, _, apiErr := c.gh.Issues.CreateComment(context.Background(), c.owner, c.repo, number, &gogithub.IssueComment{
133+
Body: &body,
134+
})
135+
return apiErr
123136
})
124137
if err != nil {
125138
return fmt.Errorf("commenting on PR #%d: %w", number, err)
@@ -129,7 +142,12 @@ func (c *Client) CommentOnPR(number int, body string) error {
129142

130143
// GetAuthenticatedUser returns the login of the authenticated user.
131144
func (c *Client) GetAuthenticatedUser() (string, error) {
132-
user, _, err := c.gh.Users.Get(context.Background(), "")
145+
var user *gogithub.User
146+
err := retry.Do(func() error {
147+
var apiErr error
148+
user, _, apiErr = c.gh.Users.Get(context.Background(), "")
149+
return apiErr
150+
})
133151
if err != nil {
134152
return "", fmt.Errorf("getting authenticated user: %w", err)
135153
}
@@ -138,8 +156,11 @@ func (c *Client) GetAuthenticatedUser() (string, error) {
138156

139157
// RequestReviewers adds reviewers to a pull request.
140158
func (c *Client) RequestReviewers(number int, reviewers []string) error {
141-
_, _, err := c.gh.PullRequests.RequestReviewers(context.Background(), c.owner, c.repo, number, gogithub.ReviewersRequest{
142-
Reviewers: reviewers,
159+
err := retry.Do(func() error {
160+
_, _, apiErr := c.gh.PullRequests.RequestReviewers(context.Background(), c.owner, c.repo, number, gogithub.ReviewersRequest{
161+
Reviewers: reviewers,
162+
})
163+
return apiErr
143164
})
144165
if err != nil {
145166
return fmt.Errorf("requesting reviewers on PR #%d: %w", number, err)

internal/github/pr.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"io"
88
"net/http"
99
"strings"
10+
11+
"github.com/omarkohl/jip/internal/retry"
1012
)
1113

1214
// PRInfo holds the essential fields of a pull request.
@@ -59,15 +61,32 @@ func (c *Client) LookupPRsByBranch(branches []string) (map[string]*PRInfo, error
5961
req.Header.Set("Authorization", "bearer "+c.token)
6062
req.Header.Set("Content-Type", "application/json")
6163

62-
resp, err := http.DefaultClient.Do(req)
63-
if err != nil {
64-
return nil, fmt.Errorf("sending request: %w", err)
65-
}
66-
defer func() { _ = resp.Body.Close() }()
64+
var resp *http.Response
65+
var rawBody []byte
66+
err = retry.Do(func() error {
67+
// Reset the request body for each attempt.
68+
req.Body = io.NopCloser(bytes.NewReader(body))
69+
70+
var doErr error
71+
resp, doErr = http.DefaultClient.Do(req)
72+
if doErr != nil {
73+
return doErr
74+
}
75+
76+
rawBody, doErr = io.ReadAll(resp.Body)
77+
_ = resp.Body.Close()
78+
if doErr != nil {
79+
return doErr
80+
}
6781

68-
rawBody, err := io.ReadAll(resp.Body)
82+
// Retry on server errors (5xx); don't retry client errors (4xx).
83+
if resp.StatusCode >= 500 {
84+
return fmt.Errorf("GitHub API returned %d: %s", resp.StatusCode, string(rawBody))
85+
}
86+
return nil
87+
})
6988
if err != nil {
70-
return nil, fmt.Errorf("reading response: %w", err)
89+
return nil, fmt.Errorf("sending request: %w", err)
7190
}
7291

7392
if resp.StatusCode != 200 {

internal/jj/runner.go

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"fmt"
55
"os/exec"
66
"strings"
7+
8+
"github.com/omarkohl/jip/internal/retry"
79
)
810

911
// logTemplate is the jj template that outputs one JSON object per line.
@@ -127,32 +129,36 @@ func (r *realRunner) GitRemoteList() ([]byte, error) {
127129
}
128130

129131
func (r *realRunner) GitFetch(remote string) error {
130-
args := []string{"git", "fetch", "-R", r.repoDir, "--remote", remote}
131-
cmd := exec.Command("jj", args...)
132-
out, err := cmd.CombinedOutput()
133-
if err != nil {
134-
return fmt.Errorf("jj git fetch: %w\n%s", err, strings.TrimSpace(string(out)))
135-
}
136-
return nil
132+
return retry.Do(func() error {
133+
args := []string{"git", "fetch", "-R", r.repoDir, "--remote", remote}
134+
cmd := exec.Command("jj", args...)
135+
out, err := cmd.CombinedOutput()
136+
if err != nil {
137+
return fmt.Errorf("jj git fetch: %w\n%s", err, strings.TrimSpace(string(out)))
138+
}
139+
return nil
140+
})
137141
}
138142

139143
func (r *realRunner) GitPush(bookmarks []string, allowNew bool, remote string) error {
140-
args := []string{"git", "push", "-R", r.repoDir}
141-
if remote != "" {
142-
args = append(args, "--remote", remote)
143-
}
144-
for _, b := range bookmarks {
145-
args = append(args, "-b", b)
146-
}
147-
if allowNew {
148-
args = append(args, "--allow-new")
149-
}
150-
cmd := exec.Command("jj", args...)
151-
out, err := cmd.CombinedOutput()
152-
if err != nil {
153-
return fmt.Errorf("jj git push: %w\n%s", err, strings.TrimSpace(string(out)))
154-
}
155-
return nil
144+
return retry.Do(func() error {
145+
args := []string{"git", "push", "-R", r.repoDir}
146+
if remote != "" {
147+
args = append(args, "--remote", remote)
148+
}
149+
for _, b := range bookmarks {
150+
args = append(args, "-b", b)
151+
}
152+
if allowNew {
153+
args = append(args, "--allow-new")
154+
}
155+
cmd := exec.Command("jj", args...)
156+
out, err := cmd.CombinedOutput()
157+
if err != nil {
158+
return fmt.Errorf("jj git push: %w\n%s", err, strings.TrimSpace(string(out)))
159+
}
160+
return nil
161+
})
156162
}
157163

158164
func (r *realRunner) Interdiff(from, to string) (string, error) {

internal/retry/retry.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package retry
2+
3+
import (
4+
"math"
5+
"math/rand/v2"
6+
"time"
7+
)
8+
9+
// config holds retry parameters.
10+
type config struct {
11+
maxAttempts int
12+
initialBackoff time.Duration
13+
multiplier float64
14+
maxBackoff time.Duration
15+
}
16+
17+
// Option configures retry behavior.
18+
type Option func(*config)
19+
20+
// WithMaxAttempts sets the maximum number of attempts (default 3).
21+
func WithMaxAttempts(n int) Option {
22+
return func(c *config) { c.maxAttempts = n }
23+
}
24+
25+
// WithInitialBackoff sets the initial backoff duration (default 1s).
26+
func WithInitialBackoff(d time.Duration) Option {
27+
return func(c *config) { c.initialBackoff = d }
28+
}
29+
30+
// WithMultiplier sets the backoff multiplier (default 2.0).
31+
func WithMultiplier(m float64) Option {
32+
return func(c *config) { c.multiplier = m }
33+
}
34+
35+
// WithMaxBackoff sets the maximum backoff duration (default 30s).
36+
func WithMaxBackoff(d time.Duration) Option {
37+
return func(c *config) { c.maxBackoff = d }
38+
}
39+
40+
// Do calls fn up to maxAttempts times, sleeping with exponential backoff
41+
// and jitter between attempts. Returns the last error if all attempts fail.
42+
func Do(fn func() error, opts ...Option) error {
43+
cfg := config{
44+
maxAttempts: 3,
45+
initialBackoff: 1 * time.Second,
46+
multiplier: 2.0,
47+
maxBackoff: 30 * time.Second,
48+
}
49+
for _, o := range opts {
50+
o(&cfg)
51+
}
52+
53+
var err error
54+
for attempt := range cfg.maxAttempts {
55+
err = fn()
56+
if err == nil {
57+
return nil
58+
}
59+
if attempt < cfg.maxAttempts-1 {
60+
backoff := float64(cfg.initialBackoff) * math.Pow(cfg.multiplier, float64(attempt))
61+
if backoff > float64(cfg.maxBackoff) {
62+
backoff = float64(cfg.maxBackoff)
63+
}
64+
// Add jitter: 50-100% of the computed backoff.
65+
jittered := time.Duration(backoff * (0.5 + rand.Float64()*0.5))
66+
time.Sleep(jittered)
67+
}
68+
}
69+
return err
70+
}

internal/retry/retry_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package retry
2+
3+
import (
4+
"errors"
5+
"testing"
6+
"time"
7+
)
8+
9+
func TestDoSucceedsImmediately(t *testing.T) {
10+
calls := 0
11+
err := Do(func() error {
12+
calls++
13+
return nil
14+
}, WithInitialBackoff(time.Millisecond))
15+
if err != nil {
16+
t.Fatalf("unexpected error: %v", err)
17+
}
18+
if calls != 1 {
19+
t.Fatalf("expected 1 call, got %d", calls)
20+
}
21+
}
22+
23+
func TestDoRetriesOnError(t *testing.T) {
24+
calls := 0
25+
err := Do(func() error {
26+
calls++
27+
if calls < 3 {
28+
return errors.New("transient")
29+
}
30+
return nil
31+
}, WithMaxAttempts(3), WithInitialBackoff(time.Millisecond))
32+
if err != nil {
33+
t.Fatalf("unexpected error: %v", err)
34+
}
35+
if calls != 3 {
36+
t.Fatalf("expected 3 calls, got %d", calls)
37+
}
38+
}
39+
40+
func TestDoExhaustsAttempts(t *testing.T) {
41+
calls := 0
42+
sentinel := errors.New("persistent")
43+
err := Do(func() error {
44+
calls++
45+
return sentinel
46+
}, WithMaxAttempts(4), WithInitialBackoff(time.Millisecond))
47+
if !errors.Is(err, sentinel) {
48+
t.Fatalf("expected sentinel error, got: %v", err)
49+
}
50+
if calls != 4 {
51+
t.Fatalf("expected 4 calls, got %d", calls)
52+
}
53+
}
54+
55+
func TestDoRespectsBackoff(t *testing.T) {
56+
start := time.Now()
57+
calls := 0
58+
_ = Do(func() error {
59+
calls++
60+
if calls < 3 {
61+
return errors.New("fail")
62+
}
63+
return nil
64+
}, WithMaxAttempts(3), WithInitialBackoff(10*time.Millisecond), WithMultiplier(1.0))
65+
elapsed := time.Since(start)
66+
// With 2 sleeps of ~10ms (jittered to 5-10ms each), expect at least 8ms total.
67+
if elapsed < 8*time.Millisecond {
68+
t.Fatalf("expected backoff delay, but elapsed was only %v", elapsed)
69+
}
70+
}
71+
72+
func TestDoSucceedsAfterOneFailure(t *testing.T) {
73+
calls := 0
74+
err := Do(func() error {
75+
calls++
76+
if calls == 1 {
77+
return errors.New("first fail")
78+
}
79+
return nil
80+
}, WithMaxAttempts(2), WithInitialBackoff(time.Millisecond))
81+
if err != nil {
82+
t.Fatalf("unexpected error: %v", err)
83+
}
84+
if calls != 2 {
85+
t.Fatalf("expected 2 calls, got %d", calls)
86+
}
87+
}
88+
89+
func TestDoSingleAttempt(t *testing.T) {
90+
sentinel := errors.New("fail")
91+
err := Do(func() error {
92+
return sentinel
93+
}, WithMaxAttempts(1))
94+
if !errors.Is(err, sentinel) {
95+
t.Fatalf("expected sentinel error, got: %v", err)
96+
}
97+
}

0 commit comments

Comments
 (0)